RFC: IPVS Unit Tests

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



tl;dr: I've attached the outline of some basic unit tests (v4 and v6 tests for round robin, but the sky's the limit) and would like feedback.

As previously mentioned, we use ipvs /a lot/ at Facebook. We have a bunch of diffs that are pretty hacky to extend it to do stuff we need (mostly the ability to schedule flows at any point, because we ECMP to various ipvs instances and that can shift). Eventually my aim is for that functionality to reach you guys in a more acceptable form, but at the moment that's the state of the world.

Anyway, in the process of forward porting stuff, I wrote a ton of unit tests for the custom functionality, but it's kind of a drag to extend and I don't think any of you are going to be super thrilled about unit tests that require our open source libraries (folly) and C++11. So I started over in Python to see how far I could get this afternoon.

These tests transmit over a raw socket and receive over a tun device. I was running it on a 3.2 host, but it should work all the way forward to *-next (the C++ version does). It's pretty easy to extend and mess with.

I've attached the entire script, but here's a simple test:

    def do_stickiness_test(self, vip, port, real_servers, src_ip, iph_fn):
        """Check that each source port keeps landing on the same destination.

        Sends 10000 TCP SYNs to vip:port, cycling through 43 source ports,
        and reads each forwarded packet back from the tun device.  The
        destination address seen for the first packet of a source port is
        remembered; every later packet from that port must match it.

        iph_fn is the IP header class used to build the packets and src_ip
        is the source address placed in them.
        """
        self.tun.add_routes(real_servers)

        # For some reason this is necessary to not break ipv6?
        add_service('192.168.255.38', 999, [])
        add_service(vip, port, real_servers)

        # Maps source port -> destination address of its first packet.
        buckets = {}
        base_port = 15000
        for i in xrange(10000):
            src = base_port + (i % 43)
            p = iph_fn(src_ip, vip, payload=TcpHeader(src, port, syn=1))
            send_raw_packet(p)
            dst_ip = parse_packet(self.tun.read()).dst
            if src not in buckets:
                buckets[src] = dst_ip
            else:
                self.assertEqual(buckets[src], dst_ip)

    def test_stickiness_v4(self):
        """Run the stickiness check over IPv4 with 27 real servers."""
        real_servers = ['1.2.3.%d' % i for i in xrange(5, 32)]
        self.do_stickiness_test(
            '1.2.3.4', 15213, real_servers, '9.9.9.9', Ipv4Header)

The raw socket stuff is just generally useful for other stuff too, if you're into that kind of thing.

Anyway, please let me know what you think

Thanks,

Alex
import sys
import collections
import fcntl
import os
import select
import socket
import struct
import subprocess
import unittest

def parse_payload(proto, data):
    """Decode the bytes that follow an IP header.

    proto is the protocol number taken from the enclosing IP header and data
    is the raw payload.  Returns the parsed header object for IPv4-in-IP,
    IPv6-in-IP, TCP and UDP payloads, and None for any other protocol.
    """
    # The lambdas postpone looking up the header classes until one is
    # actually needed.
    decoders = (
        (socket.IPPROTO_IPIP, lambda: Ipv4Header.parse(data)),
        (socket.IPPROTO_IPV6, lambda: Ipv6Header.parse(data)),
        (socket.IPPROTO_TCP, lambda: TcpHeader.parse(data)),
        (socket.IPPROTO_UDP, lambda: UdpHeader.parse(data)),
    )
    for number, decode in decoders:
        if proto == number:
            return decode()
    return None

def internet_csum(data):
    """Return the 16-bit Internet checksum (RFC 1071) of data.

    data is a byte string (str on Python 2, bytes/bytearray on Python 3).
    The bytes are summed as big-endian 16-bit words with end-around carry;
    an odd trailing byte is treated as the high byte of a final word.  The
    one's complement of that sum is returned as a plain integer, ready to be
    packed with a network-order ('!H') struct format.

    The words are read in network order directly, so the result does not
    depend on the byte order of the host.
    """
    # bytearray yields integers when indexed on both Python 2 and Python 3,
    # and copying means the padding below never touches the caller's buffer.
    octets = bytearray(data)
    if len(octets) % 2:
        octets.append(0)

    total = 0
    for i in range(0, len(octets), 2):
        total += (octets[i] << 8) | octets[i + 1]
        # Fold the carry back in so the running sum stays within 16 bits.
        total = (total & 0xffff) + (total >> 16)

    return ~total & 0xffff

def send_raw_packet(packet):
    """Serialize packet and transmit it on a fresh raw socket.

    packet must provide serialize(), a dst address string and a class-level
    address_family (as Ipv4Header and Ipv6Header do).  The socket is created
    with IPPROTO_RAW and is always closed, even when serializing or sending
    raises.
    """
    s = socket.socket(packet.__class__.address_family, socket.SOCK_RAW,
                      socket.IPPROTO_RAW)
    try:
        s.sendto(packet.serialize(), 0, (packet.dst, 0))
    finally:
        s.close()

def parse_packet(s):
    """Parse a raw IP packet into an Ipv4Header or Ipv6Header.

    s is the packet as a byte string, or None when nothing was received.
    Returns None for None or an empty buffer.  Otherwise the version nibble
    of the first byte selects the parser: 4 means IPv4 and any other value
    is handed to the IPv6 parser.
    """
    if not s:
        return None

    # Indexing a bytearray gives an integer on both Python 2 and Python 3.
    version = bytearray(s[:1])[0] >> 4
    if version == 4:
        return Ipv4Header.parse(s)
    return Ipv6Header.parse(s)

class Tun:
    """A freshly created tun interface that packets can be read back from.

    Attributes:
        tun:  the open /dev/net/tun file object.
        name: the interface name the kernel assigned (from the 'tun%d'
              template).
    """

    def __init__(self):
        """Create the interface and bring its link up."""
        # ioctl request number and flag bits as defined in linux/if_tun.h.
        TUNSETIFF = 0x400454ca
        IFF_TUN = 0x0001
        IFF_NO_PI = 0x1000
        # Unbuffered binary mode: packets are read with os.read() on the
        # descriptor, never through the file object.
        self.tun = open('/dev/net/tun', 'rb', 0)
        try:
            ifr = struct.pack('16sH', b'tun%d', IFF_TUN | IFF_NO_PI)
            ifr = fcntl.ioctl(self.tun, TUNSETIFF, ifr)
            name = struct.unpack('16sH', ifr)[0].split(b'\0')[0]
        except Exception:
            # Do not leave the device open if the interface was not created.
            self.tun.close()
            raise
        if not isinstance(name, str):
            # Python 3: the ioctl buffer is bytes, commands below want text.
            name = name.decode('ascii')
        self.name = name

        # The exit status is not checked here.
        subprocess.call(['ip', 'link', 'set', 'dev', self.name, 'up'])

    def add_route(self, rt):
        """Route destination rt through this interface (status not checked)."""
        subprocess.call(['ip', 'route', 'add', 'dev', self.name, rt])

    def add_routes(self, routes):
        """Call add_route() for every destination in routes."""
        for rt in routes:
            self.add_route(rt)

    def read(self, tmo=1):
        """Return the next packet (up to 1500 bytes), or None on timeout.

        tmo is the number of seconds to wait for the device to become
        readable.
        """
        readable, _, _ = select.select([self.tun], [], [], tmo)
        if not readable:
            return None
        return os.read(self.tun.fileno(), 1500)

    def close(self):
        """Close the device file."""
        self.tun.close()

class Ipv4Header:
    """The fixed 20-byte IPv4 header plus the payload object it carries.

    src and dst are dotted-quad strings.  The fields proto, tot_len, ident
    and csum are filled in by serialize() when they are left at 0.
    """
    fmt = '!BBHHHBBH4s4s'
    encaps_proto = socket.IPPROTO_IPIP
    address_family = socket.AF_INET

    def __init__(self, src, dst, proto=0, version=4, ihl=5, tos=0, tot_len=0,
                 ident=0, frag_off=0, ttl=42, csum=0, payload=None):
        self.src = src
        self.dst = dst
        self.proto = proto
        self.version = version
        self.ihl = ihl
        self.tos = tos
        self.tot_len = tot_len
        self.ident = ident
        self.frag_off = frag_off
        self.ttl = ttl
        self.csum = csum
        self.payload = payload

    @staticmethod
    def parse(data):
        """Build an Ipv4Header from raw bytes, parsing the payload as well.

        Only the first 20 bytes are decoded as header fields; the payload is
        taken from offset ihl * 4 up to tot_len.
        """
        (ver_ihl, tos, tot_len, ident, frag_off, ttl, proto, csum,
         src_raw, dst_raw) = struct.unpack(Ipv4Header.fmt, data[:20])
        ihl = ver_ihl & 0xf
        return Ipv4Header(
            src=socket.inet_ntop(socket.AF_INET, src_raw),
            dst=socket.inet_ntop(socket.AF_INET, dst_raw),
            proto=proto,
            version=ver_ihl >> 4,
            ihl=ihl,
            tos=tos,
            tot_len=tot_len,
            ident=ident,
            frag_off=frag_off,
            ttl=ttl,
            csum=csum,
            payload=parse_payload(proto, data[ihl * 4:tot_len]))

    def __pseudo_header_fn(self, length):
        """Return the IPv4 pseudo header for a payload of length bytes."""
        return struct.pack(
            '!4s4sBBH',
            socket.inet_pton(socket.AF_INET, self.src),
            socket.inet_pton(socket.AF_INET, self.dst),
            0,
            self.payload.__class__.encaps_proto,
            length)

    def serialize(self, ph_fn=None):
        """Return the header followed by the serialized payload.

        ph_fn is accepted so the header can itself be used as a payload; it
        is not used.  Fields that are still 0 (proto, tot_len, ident, csum)
        are computed here and stored back on the object.
        """
        ver_ihl = (self.version << 4) | self.ihl
        body = self.payload.serialize(ph_fn=self.__pseudo_header_fn)

        # A value of 0 means "not chosen yet" for each of these fields.
        if self.proto == 0:
            self.proto = self.payload.__class__.encaps_proto
        if self.tot_len == 0:
            self.tot_len = 20 + len(body)
        if self.ident == 0:
            self.ident = 0xf00f

        src_raw = socket.inet_pton(socket.AF_INET, self.src)
        dst_raw = socket.inet_pton(socket.AF_INET, self.dst)

        def pack_with(csum):
            return struct.pack(self.__class__.fmt, ver_ihl, self.tos,
                               self.tot_len, self.ident, self.frag_off,
                               self.ttl, self.proto, csum, src_raw, dst_raw)

        if self.csum == 0:
            # The checksum covers the header with its checksum field zeroed.
            self.csum = internet_csum(pack_with(0))

        return pack_with(self.csum) + body

    def __repr__(self):
        template = ('Ipv4Header(src="{s.src}", dst="{s.dst}", proto="{s.proto}", '
                    'version={s.version}, ihl={s.ihl}, tos={s.tos}, '
                    'tot_len={s.tot_len}, ident={s.ident}, '
                    'frag_off={s.frag_off}, ttl={s.ttl}, csum=0x{s.csum:x}, '
                    'payload={s.payload})')
        return template.format(s=self)

class Ipv6Header:
    """The fixed 40-byte IPv6 header plus the payload object it carries.

    src and dst are textual IPv6 addresses.  The fields proto and
    payload_len are filled in by serialize() when they are left at 0.
    """
    fmt = '!IHBB16s16s'
    encaps_proto = socket.IPPROTO_IPV6
    address_family = socket.AF_INET6

    def __init__(self, src, dst, proto=0, version=6, tclass=0, flow_label=0,
                 payload_len=0, hop_limit=42, payload=None):
        self.src = src
        self.dst = dst
        self.proto = proto
        self.version = version
        self.tclass = tclass
        self.flow_label = flow_label
        self.payload_len = payload_len
        self.hop_limit = hop_limit
        self.payload = payload

    @staticmethod
    def parse(data):
        """Build an Ipv6Header from raw bytes, parsing the payload as well.

        The payload is the payload_len bytes that follow the 40-byte header.
        """
        (first_word, payload_len, proto, hop_limit,
         src_raw, dst_raw) = struct.unpack(Ipv6Header.fmt, data[:40])
        # First 32 bits: 4-bit version, 8-bit traffic class, 20-bit label.
        return Ipv6Header(
            src=socket.inet_ntop(socket.AF_INET6, src_raw),
            dst=socket.inet_ntop(socket.AF_INET6, dst_raw),
            proto=proto,
            version=(first_word >> 28) & 0xf,
            tclass=(first_word >> 20) & 0xff,
            flow_label=first_word & 0xfffff,
            payload_len=payload_len,
            hop_limit=hop_limit,
            payload=parse_payload(proto, data[40:payload_len + 40]))

    def __pseudo_header_fn(self, length):
        """Return the IPv6 pseudo header for a payload of length bytes."""
        return struct.pack(
            '!16s16sIBBBB',
            socket.inet_pton(socket.AF_INET6, self.src),
            socket.inet_pton(socket.AF_INET6, self.dst),
            length,
            0, 0, 0,
            self.payload.__class__.encaps_proto)

    def serialize(self, ph_fn=None):
        """Return the header followed by the serialized payload.

        ph_fn is accepted so the header can itself be used as a payload; it
        is not used.  proto and payload_len are computed and stored back on
        the object when they are still 0.
        """
        first_word = ((self.version << 28) | (self.tclass << 20) |
                      self.flow_label)
        body = self.payload.serialize(ph_fn=self.__pseudo_header_fn)

        # A value of 0 means "not chosen yet" for these two fields.
        if self.proto == 0:
            self.proto = self.payload.__class__.encaps_proto
        if self.payload_len == 0:
            self.payload_len = len(body)

        header = struct.pack(
            self.__class__.fmt,
            first_word,
            self.payload_len,
            self.proto,
            self.hop_limit,
            socket.inet_pton(socket.AF_INET6, self.src),
            socket.inet_pton(socket.AF_INET6, self.dst))
        return header + body

    def __repr__(self):
        template = ('Ipv6Header(src="{s.src}", dst="{s.dst}", proto="{s.proto}", '
                    'version={s.version}, tclass={s.tclass}, '
                    'flow_label={s.flow_label}, payload_len={s.payload_len}, '
                    'hop_limit={s.hop_limit}, payload={s.payload})')
        return template.format(s=self)

class TcpHeader:
    """The fixed 20-byte TCP header plus an opaque payload byte string.

    src and dst are port numbers.  TCP options are not represented: parse()
    skips over them and serialize() always emits exactly 20 header bytes.
    A csum of 0 means "compute it in serialize()".
    """
    fmt = '!HHIIHHHH'
    encaps_proto = socket.IPPROTO_TCP

    # Flag attributes ordered from the most significant bit (bit 8) down.
    _FLAG_NAMES = ('ns', 'cwr', 'ece', 'urg', 'ack', 'psh', 'rst', 'syn',
                   'fin')

    def __init__(self, src, dst, seq=0, ack_num=0, data_off=5, ns=0, cwr=0, ece=0,
                 urg=0, ack=0, psh=0, rst=0, syn=0, fin=0, win_size=100, csum=0,
                 urg_ptr=0, payload=b''):
        self.src, self.dst, self.seq, self.ack_num = src, dst, seq, ack_num
        self.data_off, self.ns, self.cwr, self.ece = data_off, ns, cwr, ece
        self.urg, self.ack, self.psh, self.rst = urg, ack, psh, rst
        self.syn, self.fin, self.win_size, self.csum = syn, fin, win_size, csum
        self.urg_ptr, self.payload = urg_ptr, payload

    @staticmethod
    def parse(data):
        """Build a TcpHeader from raw segment bytes.

        The data offset found in the segment is used only to locate the
        payload; the returned object keeps the default data_off.
        """
        fields = struct.unpack(TcpHeader.fmt, data[:20])
        src, dst, seq, ack_num, flags, win_size, csum, urg_ptr = fields
        data_off = (flags >> 12) & 0xf

        # Walk the names from the lowest bit (FIN, bit 0) upwards.
        bits = {}
        for shift, name in enumerate(reversed(TcpHeader._FLAG_NAMES)):
            bits[name] = (flags >> shift) & 0x1

        return TcpHeader(src=src, dst=dst, seq=seq, ack_num=ack_num,
                         win_size=win_size, csum=csum, urg_ptr=urg_ptr,
                         payload=data[(data_off * 4):], **bits)

    def serialize(self, ph_fn):
        """Return the 20-byte header followed by the payload.

        ph_fn(length) must return the pseudo header of the enclosing IP
        header; it is only called when csum is 0, in which case the checksum
        is computed and stored back on the object.  A data_off of 0 is
        replaced by 5.
        """
        if self.data_off == 0:
            self.data_off = 5

        # Explicit fold instead of reduce(), which is not a builtin on
        # Python 3.
        flags = 0
        for name in self._FLAG_NAMES:
            flags = (flags << 1) | getattr(self, name)
        flags |= self.data_off << 12

        def pack_with(csum):
            header = struct.pack(self.__class__.fmt, self.src, self.dst,
                                 self.seq, self.ack_num, flags, self.win_size,
                                 csum, self.urg_ptr)
            return header + self.payload

        if self.csum == 0:
            # Checksum covers pseudo header + segment with a zeroed field.
            pseudo = ph_fn(20 + len(self.payload))
            self.csum = internet_csum(pseudo + pack_with(0))

        return pack_with(self.csum)

    def __repr__(self):
        return ('TcpHeader(src={s.src}, dst={s.dst}, seq={s.seq}, ' +
                'ack_num={s.ack_num}, ns={s.ns}, cwr={s.cwr}, ece={s.ece}, ' +
                'urg={s.urg}, ack={s.ack}, psh={s.psh}, rst={s.rst}, ' +
                'syn={s.syn}, fin={s.fin}, win_size={s.win_size}, ' +
                'csum=0x{s.csum:x}, urg_ptr={s.urg_ptr})').format(s=self)

class UdpHeader:
    """The 8-byte UDP header plus an opaque payload byte string.

    src and dst are port numbers.  On the wire the fields appear in the
    order source port, destination port, length, checksum (RFC 768).  A csum
    of 0 means "compute it in serialize()".
    """
    fmt = '!HHHH'
    encaps_proto = socket.IPPROTO_UDP

    def __init__(self, src, dst, csum=0, payload=b''):
        self.src, self.dst, self.csum, self.payload = src, dst, csum, payload

    @staticmethod
    def parse(data):
        """Build a UdpHeader from raw datagram bytes.

        The length field counts the header too, so the payload is
        data[8:length].
        """
        src, dst, length, csum = struct.unpack(UdpHeader.fmt, data[:8])
        return UdpHeader(src=src, dst=dst, csum=csum, payload=data[8:length])

    def serialize(self, ph_fn):
        """Return the 8-byte header followed by the payload.

        ph_fn(length) must return the pseudo header of the enclosing IP
        header; it is only called when csum is 0, in which case the checksum
        is computed and stored back on the object.
        """
        length = 8 + len(self.payload)

        def pack_with(csum):
            header = struct.pack(self.__class__.fmt, self.src, self.dst,
                                 length, csum)
            return header + self.payload

        if self.csum == 0:
            computed = internet_csum(ph_fn(length) + pack_with(0))
            # RFC 768: a computed checksum of zero is sent as all ones,
            # because zero on the wire means "no checksum".
            self.csum = computed or 0xffff

        return pack_with(self.csum)

    def __repr__(self):
        return ('UdpHeader(src={s.src}, dst={s.dst}, csum=0x{s.csum:x}, ' +
                'payload={s.payload})').format(s=self)

# Shell snippet run (with shell=True) from the test setUp.  It clears the
# ipvs table, then rmmod's every loaded module whose name starts with
# "ip_vs_", and finally rmmod's ip_vs itself.  The steps are chained with
# "&&", so a failing step skips the ones after it.
RESET_IPVS = """\
ipvsadm --clear &&
lsmod | grep '^ip_vs_' | cut -d' ' -f1 | xargs -r -n1 rmmod &&
rmmod ip_vs;
"""

def add_service(vip, port, real_servers, scheduler='rr', service_type='tcp'):
    """Create an ipvs virtual service and attach its real servers.

    Args:
        vip: virtual IP address; an address containing ':' is wrapped in
            square brackets unless it already starts with one.
        port: virtual service port number.
        real_servers: addresses added as --ipip real servers, in reverse
            order of the sequence.
        scheduler: value passed to ipvsadm --scheduler.
        service_type: 'tcp' selects --tcp-service, anything else selects
            --udp-service.

    Raises:
        subprocess.CalledProcessError: if any ipvsadm invocation exits with
            a non-zero status.
    """
    if ':' in vip and not vip.startswith('['):
        vip = '[' + vip + ']'

    service = '%s:%d' % (vip, port)
    service_flag = '--tcp-service' if service_type == 'tcp' else '--udp-service'
    forwarding_flag = '--ipip'

    # check_call rather than "assert call(...) == 0": an assert is removed
    # under python -O, which would drop the ipvsadm invocation itself.
    subprocess.check_call(['ipvsadm', '-A', service_flag, service,
                           '--scheduler', scheduler])
    for rs in reversed(real_servers):
        subprocess.check_call(['ipvsadm', '-a', service_flag, service,
                               '--real-server', rs, forwarding_flag])

class TestIpvsRoundRobin(unittest.TestCase):
    """Tests of the ipvs round robin scheduler over IPv4 and IPv6.

    Each test sends TCP SYNs to a virtual service over a raw socket and
    reads the forwarded packets back from a tun device that the real server
    addresses are routed to.
    """

    def setUp(self):
        """Reset ipvs and create a fresh tun device."""
        subprocess.call(RESET_IPVS, shell=True)
        self.tun = Tun()

    def tearDown(self):
        """Close the tun device."""
        self.tun.close()
        self.tun = None

    def _start_service(self, vip, port, real_servers):
        """Route the real servers to the tun device and create the service."""
        self.tun.add_routes(real_servers)

        # For some reason this is necessary to not break ipv6?
        add_service('192.168.255.38', 999, [])
        add_service(vip, port, real_servers)

    def _forwarded_dst(self, iph_fn, src_ip, vip, src_port, dst_port):
        """Send one SYN and return the dst of the packet read from the tun."""
        p = iph_fn(src_ip, vip, payload=TcpHeader(src_port, dst_port, syn=1))
        send_raw_packet(p)
        raw = self.tun.read()
        # Fail with a clear message on timeout instead of None.dst.
        self.assertTrue(raw is not None,
                        'no packet arrived on the tun device for source '
                        'port %d' % src_port)
        return parse_packet(raw).dst

    def do_balance_test(self, vip, port, real_servers, src_ip, iph_fn):
        """Check that 10000 distinct flows are spread evenly.

        Every packet uses a different source port.  Each real server must
        receive at least one packet, and no per-destination count may differ
        from the average by more than one.
        """
        self._start_service(vip, port, real_servers)

        buckets = collections.defaultdict(int)
        base_port = 15000
        for i in range(10000):
            dst_ip = self._forwarded_dst(iph_fn, src_ip, vip, base_port + i,
                                         port)
            buckets[dst_ip] += 1

        # A server that never got a packet has no bucket, so it would be
        # invisible to the per-bucket check below.
        self.assertEqual(len(buckets), len(real_servers))

        avg = float(sum(buckets.values())) / len(buckets)
        for dst_ip, count in buckets.items():
            self.assertTrue(abs(count - avg) <= 1.0,
                            '%s received %d packets, average is %.2f'
                            % (dst_ip, count, avg))

    def test_balance_v4(self):
        real_servers = ['1.2.3.%d' % i for i in range(5, 32)]
        self.do_balance_test(
            '1.2.3.4', 15213, real_servers, '9.9.9.9', Ipv4Header)

    def test_balance_v6(self):
        real_servers = ['face::%x' % i for i in range(5, 32)]
        self.do_balance_test(
            'face::4', 15213, real_servers, 'b00c::1', Ipv6Header)

    def do_stickiness_test(self, vip, port, real_servers, src_ip, iph_fn):
        """Check that each source port keeps landing on the same destination.

        10000 SYNs are sent, cycling through 43 source ports.  The
        destination seen for the first packet of a port is remembered and
        every later packet from that port must match it.
        """
        self._start_service(vip, port, real_servers)

        first_seen = {}
        base_port = 15000
        for i in range(10000):
            src = base_port + (i % 43)
            dst_ip = self._forwarded_dst(iph_fn, src_ip, vip, src, port)
            self.assertEqual(first_seen.setdefault(src, dst_ip), dst_ip)

    def test_stickiness_v4(self):
        real_servers = ['1.2.3.%d' % i for i in range(5, 32)]
        self.do_stickiness_test(
            '1.2.3.4', 15213, real_servers, '9.9.9.9', Ipv4Header)

    def test_stickiness_v6(self):
        real_servers = ['face::%x' % i for i in range(5, 32)]
        self.do_stickiness_test(
            'face::4', 15213, real_servers, 'b00c::1', Ipv6Header)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()

[Index of Archives]     [Linux Filesystem Devel]     [Linux NFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux SCSI]     [X.Org]

  Powered by Linux