[PATCH 2/3][Autotest][virt] autotest.common_lib: Add syncdata class to common_lib

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



The SyncData class synchronizes data between
multiple hosts and guests. It starts a server for
syncing data, and the hosts use that server to
exchange their data.

Signed-off-by: Jiří Župka <jzupka@xxxxxxxxxx>
---
 client/common_lib/base_barrier.py           |    2 +-
 client/common_lib/base_syncdata.py          |  271 +++++++++++++++++++++++++++
 client/common_lib/base_syncdata_unittest.py |  203 ++++++++++++++++++++
 client/common_lib/error.py                  |   10 +
 client/common_lib/syncdata.py               |   15 ++
 5 files changed, 500 insertions(+), 1 deletions(-)
 create mode 100644 client/common_lib/base_syncdata.py
 create mode 100755 client/common_lib/base_syncdata_unittest.py
 create mode 100644 client/common_lib/syncdata.py

diff --git a/client/common_lib/base_barrier.py b/client/common_lib/base_barrier.py
index d20916a..df4da49 100644
--- a/client/common_lib/base_barrier.py
+++ b/client/common_lib/base_barrier.py
@@ -50,7 +50,7 @@ class listen_server(object):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         sock.bind((self.address, self.port))
-        sock.listen(10)
+        sock.listen(100)
 
         return sock
 
diff --git a/client/common_lib/base_syncdata.py b/client/common_lib/base_syncdata.py
new file mode 100644
index 0000000..c50b6cb
--- /dev/null
+++ b/client/common_lib/base_syncdata.py
@@ -0,0 +1,271 @@
+import pickle, time, socket, errno, threading, logging, signal
+from autotest_lib.client.common_lib import error
+from autotest_lib.client.common_lib import barrier
+from autotest_lib.client.common_lib import utils
+from autotest_lib.client.bin import parallel
+
+_DEFAULT_PORT = 13234  # TCP port SyncListenServer listens on by default.
+_DEFAULT_TIMEOUT = 10  # Seconds; server-side socket/receive timeout.
+
+
+def net_send_object(sock, obj):
+    """
+    Send python object over network as a "%6d" length header + pickle.
+    @param sock: connected socket to send the object over.
+    @param obj: object to send; it must pickle to at most 999999 bytes.
+    """
+    data = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
+    if len(data) > 999999:  # a longer length would not fit the header
+        raise error.NetCommunicationError("Object too big to send.")
+    sock.sendall("%6d" % (len(data)) + data)
+
+
+def net_recv_object(sock, timeout=60):
+    """Receive and return a python object sent by net_send_object().
+    WARNING: unpickling lets the peer run arbitrary code; trusted nets only.
+    @param sock: connected socket to read the length-prefixed pickle from.
+    @param timeout: seconds allowed for the whole transfer.
+    """
+    endtime = time.time() + timeout
+
+    def read_exactly(size):
+        buf = ""
+        while len(buf) < size:  # recv() may return fewer bytes than asked
+            if time.time() > endtime:
+                raise error.NetCommunicationError("Connection timeout.")
+            chunk = sock.recv(size - len(buf))
+            if not chunk:  # "" means the peer closed the connection
+                raise ValueError("Connection closed by peer.")
+            buf += chunk
+        return buf
+    try:
+        return pickle.loads(read_exactly(int(read_exactly(6))))
+    except (socket.timeout, ValueError):
+        raise error.NetCommunicationError("Failed to receive python"
+                                          " object over the network.")
+
+
+class SessionData(object):
+    """
+    Server-side state of one sync session (data and sockets per client IP).
+    """
+    def __init__(self, hosts, timeout):
+        self.hosts = hosts
+        self.endtime = time.time() + timeout
+        self.sync_data, self.connection = {}, {}
+        self.data_lock = threading.Lock()
+        self.data_recv, self.finished = 0, False
+
+    def remaining(self):
+        """Return the seconds left until the session times out, never < 0."""
+        return max(self.endtime - time.time(), 0)
+
+    def close(self):
+        """Close the socket of every client connection of the session."""
+        for sock, _ in self.connection.values():
+            sock.close()
+
+
+class SyncListenServer(object):
+    def __init__(self, tmpdir, address='', port=_DEFAULT_PORT):
+        """
+        Start the sync server in a child process (see close()).
+
+        @param tmpdir: Dir where pid file is saved.
+        @param address: Address on which server must be started.
+        @param port: Port of server.
+        """
+        self.tmpdir = tmpdir
+        self.sessions = {}
+        self.exit_event = threading.Event()
+        self.server_pid = parallel.fork_start(
+            self.tmpdir, lambda: self._start_server(address, port))
+
+    def _clean_sessions(self):
+        """
+        Close and delete finished and timed-out sessions.
+        """
+        to_del = []
+        for session_id, session in self.sessions.items():
+            if not session.data_lock.acquire(False):
+                continue  # a _recv_data thread is working on this session
+            if session.finished or not session.remaining():
+                if not session.finished:
+                    logging.warn("Sync Session %s timeout.", session.hosts)
+                session.close()
+                to_del.append(session_id)
+            session.data_lock.release()
+        for td in to_del:
+            del self.sessions[td]
+
+    def _recv_data(self, connection, session):
+        """Store one client's data; when all hosts have sent theirs, send
+        the collected data to every client of the session."""
+        client, addr = connection
+        session.data_lock.acquire()
+        try:  # always release the lock, or the session blocks forever
+            session.connection[addr[0]] = connection
+            try:
+                logging.info("Try recv from client")
+                session.sync_data[addr[0]] = net_recv_object(client,
+                                                             _DEFAULT_TIMEOUT)
+            except (socket.error, error.NetCommunicationError):
+                logging.warn("Receiving sync data from %s failed.", addr[0])
+            # Count hosts, not connections: a retrying client reuses its IP.
+            session.data_recv = len(session.sync_data)
+            if (not session.finished and session.remaining() and
+                    session.data_recv == len(session.hosts)):
+                for peer, peer_addr in session.connection.values():
+                    try:  # one broken client must not starve the others
+                        net_send_object(peer, session.sync_data)
+                        net_recv_object(peer, _DEFAULT_TIMEOUT)
+                    except (socket.error, error.NetCommunicationError):
+                        logging.warn("Sync data to %s failed.", peer_addr[0])
+                session.finished = True
+        finally:
+            session.data_lock.release()
+
+    def __call__(self, signum, frame):  # SIGTERM handler, see _start_server
+        self.exit_event.set()
+
+    def _start_server(self, address, port):
+        signal.signal(signal.SIGTERM, self)
+        self.server_thread = utils.InterruptedThread(self._server,
+                                                (address, port))
+        self.server_thread.start()
+
+        while not self.exit_event.is_set():
+            signal.pause()
+
+        self.server_thread.join(2 * _DEFAULT_TIMEOUT)
+        for session in self.sessions.itervalues():
+            session.close()
+        self.listen_server.close()
+
+    def _server(self, address, port):
+        self.listen_server = barrier.listen_server(address, port)
+        logging.debug("Wait for clients")
+        self.listen_server.socket.settimeout(_DEFAULT_TIMEOUT)
+        while not self.exit_event.is_set():
+            connection = None
+            try:
+                connection = self.listen_server.socket.accept()
+                logging.debug("Client %s connected.", connection[1][0])
+                session_id, hosts, timeout = net_recv_object(connection[0],
+                                                             _DEFAULT_TIMEOUT)
+                self._clean_sessions()
+                if session_id not in self.sessions:
+                    logging.debug("Add new session")
+                    self.sessions[session_id] = SessionData(hosts, timeout)
+                utils.InterruptedThread(self._recv_data, (connection,
+                                        self.sessions[session_id])).start()
+            except (socket.error, error.NetCommunicationError):
+                if connection is not None:  # handshake failed, drop client
+                    connection[0].close()
+                self._clean_sessions()
+
+    def close(self):
+        """
+        Send SIGTERM to the server process and wait for it to exit; the
+        server then closes the listen socket and all client connections.
+        """
+        utils.signal_pid(self.server_pid, signal.SIGTERM)
+        if utils.pid_is_alive(self.server_pid):
+            parallel.fork_waitfor_timed(self.tmpdir, self.server_pid,
+                                        2 * _DEFAULT_TIMEOUT)
+        logging.debug("SyncListenServer was killed.")
+
+
+class SyncData(object):
+    """
+    Provides data synchronization between hosts: each host sends its
+    pickled data to the listen server at masterid and gets back the data
+    of all hosts. A master (hostid == masterid) given no listen_server
+    starts its own, which needs tmpdir. session_id separates exchanges.
+    """
+    def __init__(self, masterid, hostid, hosts, session_id=None,
+                 listen_server=None, port=_DEFAULT_PORT, tmpdir=None):
+        self.port = port
+        self.hosts = hosts
+        self.session_id = session_id
+        self.endtime = None
+
+        self.hostid = hostid
+        self.masterid = masterid
+        self.master = self.hostid == self.masterid
+        self.connection = []
+        self.server = None
+        self.killserver = False
+
+        self.listen_server = listen_server
+        if not self.listen_server and self.master:
+            if tmpdir is None:
+                raise error.DataSyncError("Tmpdir can not be None.")
+            self.listen_server = SyncListenServer(tmpdir, port=self.port)
+            self.killserver = True
+
+        self.sync_data = {}
+
+    def close(self):
+        """Stop the listen server if this object started it."""
+        if self.killserver:
+            self.listen_server.close()
+            self.killserver = False  # a second close() is then a no-op
+
+    def _remaining(self):
+        """Return seconds left until self.endtime, never negative."""
+        return max(self.endtime - time.time(), 0)
+
+    def _client(self, data, session_id, timeout):
+        """Exchange data with the listen server, retrying until timeout."""
+        if session_id is None:
+            session_id = self.session_id
+        session_id = str(session_id)
+        self.endtime = time.time() + timeout
+        logging.info("calling master: %s", self.masterid)
+        while self._remaining():
+            try:
+                self.server = socket.socket(socket.AF_INET,
+                                            socket.SOCK_STREAM)
+                self.server.settimeout(5)
+                self.server.connect((self.masterid, self.port))
+                remaining = self._remaining()
+                if not remaining:  # settimeout(0) would mean non-blocking
+                    break
+                self.server.settimeout(remaining)
+                net_send_object(self.server, (session_id, self.hosts,
+                                              self._remaining()))
+                net_send_object(self.server, data)
+                self.sync_data = net_recv_object(self.server,
+                                                 self._remaining())
+                net_send_object(self.server, "BYE")
+                return
+            except error.NetCommunicationError:
+                logging.warn("Problem with communication with server.")
+            except socket.timeout:
+                logging.warn("timeout calling host %s, retry", self.masterid)
+            except socket.error, err:
+                if err.errno != errno.ECONNREFUSED:
+                    raise
+            # Close the failed socket before retrying, so it is not leaked.
+            self.server.close()
+            self.server = None
+            time.sleep(1)
+        raise error.DataSyncError("Timeout during data sync with data %s" %
+                                  (data,))
+
+    def sync(self, data=None, timeout=60, session_id=None):
+        """Send data to the master and return the data of all hosts."""
+        try:
+            self._client(data, session_id, timeout)
+        finally:
+            if self.server:
+                self.server.close()
+        return self.sync_data
+
+    def single_sync(self, data=None, timeout=60, session_id=None):
+        """Like sync(), but close() this object afterwards."""
+        try:
+            self.sync(data, timeout, session_id)
+        finally:
+            self.close()
+        return self.sync_data
diff --git a/client/common_lib/base_syncdata_unittest.py b/client/common_lib/base_syncdata_unittest.py
new file mode 100755
index 0000000..b19393c
--- /dev/null
+++ b/client/common_lib/base_syncdata_unittest.py
@@ -0,0 +1,203 @@
+#!/usr/bin/python
+
+__author__ = """Jiri Zupka (jzupka@xxxxxxxxxx)"""
+
+import unittest
+import socket, threading, time, pickle, os
+try:
+    import autotest.common as common
+except ImportError:
+    import common
+from autotest_lib.client.common_lib.test_utils import mock
+from autotest_lib.client.common_lib import error, base_syncdata, barrier
+syncdata = base_syncdata
+
+
+class Test(unittest.TestCase):
+
+    def setUp(self):
+        self.god = mock.mock_god()
+        self.god.mock_io()
+
+    def tearDown(self):
+        self.god.unmock_io()
+
+    def test_send_receive_net_object(self):
+        ls = barrier.listen_server(port=syncdata._DEFAULT_PORT)
+
+        send_data = {'aa': ['bb', 'xx', ('ss')]}
+        server = self._start_server(ls, send_data)
+
+        recv_data = self._client("127.0.0.1", 10)
+        server.join()
+        ls.close()
+
+        self.assertEqual(recv_data, send_data)
+
+    def test_send_receive_net_object_close_connection(self):
+        ls = barrier.listen_server(port=syncdata._DEFAULT_PORT)
+
+        server = self._start_server(ls)
+
+        self.assertRaisesRegexp(error.NetCommunicationError,
+                                "Failed to receive python"
+                                " object over the network.",
+                                self._client, "127.0.0.1", 2)
+        server.join()
+        ls.close()
+
+    def test_send_receive_net_object_timeout(self):
+        ls = barrier.listen_server(port=syncdata._DEFAULT_PORT)
+
+        server = self._start_server(ls, timewait=5)
+
+        self.assertRaisesRegexp(error.NetCommunicationError,
+                                "Failed to receive python"
+                                " object over the network.",
+                                self._client, "127.0.0.1", 2)
+        server.join()
+        ls.close()
+
+    def test_send_receive_net_object_timeout_in_communication(self):
+        ls = barrier.listen_server(port=syncdata._DEFAULT_PORT)
+
+        send_data = {'aa': ['bb', 'xx', ('ss')]}
+        server = self._start_server(ls, send_data,
+                                    timewait=5, connbreak=True)
+
+        self.assertRaisesRegexp(error.NetCommunicationError,
+                                "Connection timeout.",
+                                self._client, "127.0.0.1", 2)
+        server.join()
+        ls.close()
+
+    def test_SyncListenServer_start_close(self):
+        sync_ls = syncdata.SyncListenServer("/tmp/")
+        os.kill(sync_ls.server_pid, 0)
+        time.sleep(2)
+        sync_ls.close()
+        l = lambda : os.kill(sync_ls.server_pid, 0)
+        self.assertRaises(OSError, l)
+
+    def test_SyncData_tmp_missing(self):
+        self.assertRaisesRegexp(error.DataSyncError,
+                                "Tmpdir can not be None.",
+                                syncdata.SyncData, "127.0.0.1", "127.0.0.1",
+                                ["127.0.0.1"], "127.0.0.1#1")
+
+    def test_SyncData_with_listenServer(self):
+        sync_ls = syncdata.SyncListenServer("/tmp/")
+        sync = syncdata.SyncData("127.0.0.1", "127.0.0.1", ["127.0.0.1"],
+                                 "127.0.0.1#1", sync_ls)
+        data = sync.sync("test1")
+        sync.close()
+        sync_ls.close()
+        self.assertEqual(data, {'127.0.0.1': 'test1'})
+
+    def test_SyncData_with_self_listen_server(self):
+        sync = syncdata.SyncData("127.0.0.1", "127.0.0.1", ["127.0.0.1"],
+                                 "127.0.0.1#1", tmpdir="/tmp/")
+        os.kill(sync.listen_server.server_pid, 0)
+        data = sync.sync("test2")
+        sync.close()
+        l = lambda : os.kill(sync.listen_server.server_pid, 0)
+        self.assertRaises(OSError, l)
+        self.assertEqual(data, {'127.0.0.1': 'test2'})
+
+    def test_SyncData_with_listenServer_auto_close(self):
+        sync = syncdata.SyncData("127.0.0.1", "127.0.0.1", ["127.0.0.1"],
+                                 "127.0.0.1#1", tmpdir="/tmp/")
+        os.kill(sync.listen_server.server_pid, 0)
+        data = sync.single_sync("test3")
+        l = lambda: os.kill(sync.listen_server.server_pid, 0)
+        self.assertRaises(OSError, l)
+        self.assertEqual(data, {'127.0.0.1': 'test3'})
+
+    def test_SyncData_with_closed_listenServer(self):
+        sync_ls = syncdata.SyncListenServer("/tmp/")
+        sync_ls.close()
+        time.sleep(2)
+        sync = syncdata.SyncData("127.0.0.1", "127.0.0.1", ["127.0.0.1"],
+                                 "127.0.0.1#1", sync_ls)
+
+        l = lambda: sync.sync("test1", 2)
+        self.assertRaises(error.DataSyncError, l)
+
+    class MockListenServer(syncdata.SyncListenServer):
+        def _server(self, address, port):
+            self.listen_server = barrier.listen_server(address, port)
+            self.exit_event.wait()  # listen, but never accept a client
+
+    def test_SyncData_with_listenServer_client_wait_timeout(self):
+        sync = syncdata.SyncData("127.0.0.1", "127.0.0.1",
+                                 ["127.0.0.1", "192.168.0.1"],
+                                 "127.0.0.1#1", tmpdir="/tmp/")
+
+        l = lambda: sync.single_sync("test1", 2)
+        self.assertRaises(error.DataSyncError, l)
+
+    def test_SyncData_with_listenServer_fake_server(self):
+        fake_ls = self.MockListenServer("/tmp/")
+        sync = syncdata.SyncData("127.0.0.1", "127.0.0.1",
+                                 ["127.0.0.1", "192.168.0.1"],
+                                 "127.0.0.1#1", fake_ls)
+        self.assertRaises(error.DataSyncError, sync.single_sync, "test1", 10)
+        fake_ls.close()
+
+    def test_SyncData_multiple_session(self):
+        data_check = {}
+        threads = []
+        sync_ls = syncdata.SyncListenServer("/tmp/")
+        def _client_session_thread(sync_ls, data_check, id):
+            sync = syncdata.SyncData("127.0.0.1", "127.0.0.1", ["127.0.0.1"],
+                                 "127.0.0.1#%d" % id, sync_ls)
+            data_check[id] = sync.sync("test%d" % (id))
+            sync.close()
+
+        for id in range(30):
+            server_thread = threading.Thread(target=_client_session_thread,
+                                             args=(sync_ls, data_check, id))
+            threads.append(server_thread)
+            server_thread.start()
+
+        for th in threads:
+            th.join()
+
+        sync_ls.close()
+        for id in range(30):
+            self.assertEqual(data_check[id], {'127.0.0.1': 'test%d' % (id)})
+
+    def _start_server(self, listen_server, obj=None, timewait=None,
+                      connbreak=False):
+        def _server_thread(listen_server, obj=None, timewait=None,
+                           connbreak=False):
+            sock = listen_server.socket.accept()[0]
+            if not connbreak:
+                if timewait is not None:
+                    time.sleep(timewait)
+                if obj is not None:
+                    syncdata.net_send_object(sock, obj)
+            else:
+                data = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
+                sock.sendall("%6d" % len(data))
+                for _ in range(timewait):
+                    time.sleep(1)
+                    sock.sendall(".")
+            sock.close()
+
+        server_thread = threading.Thread(target=_server_thread,
+                                         args=(listen_server, obj,
+                                               timewait, connbreak))
+        server_thread.start()
+        return server_thread
+
+    def _client(self, addr, timeout):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.connect((addr, syncdata._DEFAULT_PORT))
+        obj = syncdata.net_recv_object(sock, timeout)
+        sock.close()
+        return obj
+
+if __name__ == "__main__":
+    # Single tests can be run with: ./base_syncdata_unittest.py Test.test_x
+    unittest.main()
diff --git a/client/common_lib/error.py b/client/common_lib/error.py
index 1823143..aac3c5a 100644
--- a/client/common_lib/error.py
+++ b/client/common_lib/error.py
@@ -297,6 +297,16 @@ class BarrierAbortError(BarrierError):
     pass
 
 
+class NetCommunicationError(JobError):
+    """Indicates that network communication was broken."""
+    pass
+
+
+class DataSyncError(NetCommunicationError):
+    """Indicates a problem while synchronizing data over the network."""
+    pass
+
+
 class InstallError(JobError):
     """Indicates an installation error which Terminates and fails the job."""
     pass
diff --git a/client/common_lib/syncdata.py b/client/common_lib/syncdata.py
new file mode 100644
index 0000000..922fa4c
--- /dev/null
+++ b/client/common_lib/syncdata.py
@@ -0,0 +1,15 @@
+from autotest_lib.client.common_lib.base_syncdata import SyncData
+from autotest_lib.client.common_lib.base_syncdata import SyncListenServer
+from autotest_lib.client.common_lib.base_syncdata import net_send_object
+from autotest_lib.client.common_lib.base_syncdata import net_recv_object
+from autotest_lib.client.common_lib import utils
+
+_SITE_MODULE_NAME = 'autotest_lib.client.common_lib.site_syncdata'
+# Let utils.import_site_symbol pick a site_syncdata override for each name;
+# the base_syncdata object is passed in, presumably as the fallback.
+net_send_object, net_recv_object, SyncListenServer, SyncData = tuple(
+    utils.import_site_symbol(__file__, _SITE_MODULE_NAME, name, default)
+    for name, default in (('net_send_object', net_send_object),
+                          ('net_recv_object', net_recv_object),
+                          ('SyncListenServer', SyncListenServer),
+                          ('SyncData', SyncData)))
-- 
1.7.7.6

--
To unsubscribe from this list: send the line "unsubscribe kvm" in
the body of a message to majordomo@xxxxxxxxxxxxxxx
More majordomo info at  http://vger.kernel.org/majordomo-info.html


[Index of Archives]     [KVM ARM]     [KVM ia64]     [KVM ppc]     [Virtualization Tools]     [Spice Development]     [Libvirt]     [Libvirt Users]     [Linux USB Devel]     [Linux Audio Users]     [Yosemite Questions]     [Linux Kernel]     [Linux SCSI]     [XFree86]
  Powered by Linux