This patch changes the network filtering code to use libvirt's existing
IPv4 and IPv6 address parsers/printers rather than my self-written ones.

I am introducing a new function, virSocketGetNumNetmaskBits(), in
src/util/network.c that counts the number of bits in a netmask and
ensures that the given address is indeed a valid netmask, returning -1
on error (NULL address, unsupported address family, or non-contiguous
mask) or a value of 0-32 for IPv4 addresses and 0-128 for IPv6
addresses. virSocketAddrIsNetmask(), which checks for a valid netmask,
is now implemented by invoking this new function.

Signed-off-by: Stefan Berger <stefanb@xxxxxxxxxx>

---
 src/conf/nwfilter_conf.c                  |  181 ++----------------------------
 src/conf/nwfilter_conf.h                  |    8 ++---
 src/nwfilter/nwfilter_ebiptables_driver.c |   46 ++++---
 src/util/network.c                        |  182 +++++++++++++++-------------
 src/util/network.h                        |    3 +++
 5 files changed, 151 insertions(+), 269 deletions(-)

Index: libvirt-acl/src/conf/nwfilter_conf.h
===================================================================
--- libvirt-acl.orig/src/conf/nwfilter_conf.h
+++ libvirt-acl/src/conf/nwfilter_conf.h
@@ -33,6 +33,8 @@
 # include "util.h"
 # include "hash.h"
 # include "xml.h"
+# include "network.h"
+
 
 /**
  * Chain suffix size is:
@@ -85,11 +87,7 @@ struct _nwMACAddress {
 typedef struct _nwIPAddress nwIPAddress;
 typedef nwIPAddress *nwIPAddressPtr;
 struct _nwIPAddress {
-    int isIPv6;
-    union {
-        unsigned char ipv4Addr[4];
-        unsigned char ipv6Addr[16];
-    } addr;
+    virSocketAddr addr;
 };
 
 
Index: libvirt-acl/src/util/network.c
===================================================================
--- libvirt-acl.orig/src/util/network.c
+++ libvirt-acl/src/util/network.c
@@ -219,88 +219,10 @@ virSocketGetPort(virSocketAddrPtr addr)
  * Returns 0 in case of success and -1 in case of error
  */
 int virSocketAddrIsNetmask(virSocketAddrPtr netmask) {
-    int i;
-
-    if (netmask == NULL)
-        return(-1);
-
-    if (netmask->stor.ss_family == AF_INET) {
-        virIPv4Addr tm;
-        unsigned char tmp;
-        int ok = 0;
-
-        if (getIPv4Addr(netmask, &tm) < 0)
-            return(-1);
-
-        for (i = 0;i < 4;i++) {
-            if (tm[i] != 0)
-                break;
-        }
-
-        if (i >= 4)
-            return(0);
-
-        tmp = 0xFF;
-        do {
-            if (tm[i] == tmp) {
-                ok = 1;
-                break;
-            }
-            tmp <<= 1;
-        } while (tmp != 0);
-        if (ok == 0)
-            return(-1);
-        i++;
-
-        if (i >= 4)
-            return(0);
-
-        for (;i < 4;i++) {
-            if (tm[i] != 0xFF)
-                return(-1);
-        }
-    } else if (netmask->stor.ss_family == AF_INET6) {
-        virIPv6Addr tm;
-        unsigned short tmp;
-        int ok = 0;
-
-        /*
-         * Hum, on IPv6 people use prefixes instead of netmask
-         */
-        if (getIPv6Addr(netmask, &tm) < 0)
-            return(-1);
-
-        for (i = 0;i < 8;i++) {
-            if (tm[i] != 0)
-                break;
-        }
-
-        if (i >= 8)
-            return(0);
-
-        tmp = 0xFFFF;
-        do {
-            if (tm[i] == tmp) {
-                ok = 1;
-                break;
-            }
-            tmp <<= 1;
-        } while (tmp != 0);
-        if (ok == 0)
-            return(-1);
-        i++;
-
-        if (i >= 8)
-            return(0);
-
-        for (;i < 8;i++) {
-            if (tm[i] != 0xFFFF)
-                return(-1);
-        }
-    } else {
-        return(-1);
-    }
-    return(0);
+    int n = virSocketGetNumNetmaskBits(netmask);
+    if (n < 0)
+        return -1;
+    return 0;
 }
 
 /**
@@ -337,3 +259,99 @@ int virSocketGetRange(virSocketAddrPtr s
     }
     return(ret);
 }
+
+
+/**
+ * virSocketGetNumNetmaskBits
+ * @netmask: the presumed netmask
+ *
+ * Get the number of netmask bits in a netmask.
+ *
+ * Returns the number of bits in the netmask (0-32 for IPv4, 0-128 for
+ * IPv6) or -1 if an error occurred or the netmask is invalid.
+ */
+int virSocketGetNumNetmaskBits(const virSocketAddrPtr netmask)
+{
+    int i, j;
+    int c = 0;
+
+    if (netmask == NULL)
+        return -1;
+
+    if (netmask->stor.ss_family == AF_INET) {
+        virIPv4Addr tm;
+        uint8_t bit;
+
+        if (getIPv4Addr(netmask, &tm) < 0)
+            return -1;
+
+        /* count the octets that have all bits set */
+        for (i = 0; i < 4; i++)
+            if (tm[i] == 0xff)
+                c += 8;
+            else
+                break;
+
+        if (c == 8 * 4)
+            return c;
+
+        /* count the leading one bits of the first partial octet */
+        j = i << 3;
+        while (j < (8 * 4)) {
+            bit = 1 << (7 - (j & 7));
+            if ((tm[j >> 3] & bit)) {
+                c++;
+            } else
+                break;
+            j++;
+        }
+
+        /* all remaining bits must be zero for a contiguous netmask */
+        while (j < (8 * 4)) {
+            bit = 1 << (7 - (j & 7));
+            if ((tm[j >> 3] & bit))
+                return -1;
+            j++;
+        }
+
+        return c;
+    } else if (netmask->stor.ss_family == AF_INET6) {
+        virIPv6Addr tm;
+        uint16_t bit;
+
+        if (getIPv6Addr(netmask, &tm) < 0)
+            return -1;
+
+        /* count the 16-bit words that have all bits set */
+        for (i = 0; i < 8; i++)
+            if (tm[i] == 0xffff)
+                c += 16;
+            else
+                break;
+
+        if (c == 16 * 8)
+            return c;
+
+        /* count the leading one bits of the first partial word */
+        j = i << 4;
+        while (j < (16 * 8)) {
+            bit = 1 << (15 - (j & 0xf));
+            if ((tm[j >> 4] & bit)) {
+                c++;
+            } else
+                break;
+            j++;
+        }
+
+        /* all remaining bits must be zero for a contiguous netmask */
+        while (j < (16 * 8)) {
+            bit = 1 << (15 - (j & 0xf));
+            if ((tm[j >> 4]) & bit)
+                return -1;
+            j++;
+        }
+
+        return c;
+    }
+    return -1;
+}
Index: libvirt-acl/src/util/network.h
===================================================================
--- libvirt-acl.orig/src/util/network.h
+++ libvirt-acl/src/util/network.h
@@ -48,4 +48,7 @@ int virSocketAddrIsNetmask(virSocketAddr
 int virSocketCheckNetmask    (virSocketAddrPtr addr1,
                               virSocketAddrPtr addr2,
                               virSocketAddrPtr netmask);
+
+int virSocketGetNumNetmaskBits(const virSocketAddrPtr netmask);
+
 #endif /* __VIR_NETWORK_H__ */
Index: libvirt-acl/src/conf/nwfilter_conf.c
===================================================================
--- libvirt-acl.orig/src/conf/nwfilter_conf.c
+++ libvirt-acl/src/conf/nwfilter_conf.c
@@ -473,22 +473,6 @@ checkValidMask(unsigned char *data, int
 }
 
 
-/* check for a valid IPv4 mask */
-static bool
-checkIPv4Mask(enum attrDatatype datatype ATTRIBUTE_UNUSED, void *maskptr,
-              virNWFilterRuleDefPtr nwf ATTRIBUTE_UNUSED)
-{
-    return checkValidMask(maskptr, 4);
-}
-
-static bool
-checkIPv6Mask(enum attrDatatype datatype ATTRIBUTE_UNUSED, void *maskptr,
-              virNWFilterRuleDefPtr nwf ATTRIBUTE_UNUSED)
-{
-    return checkValidMask(maskptr, 16);
-}
-
-
 static bool
 checkMACMask(enum attrDatatype datatype ATTRIBUTE_UNUSED,
              void *macMask,
@@ -498,16 +482,6 @@ checkMACMask(enum attrDatatype datatype
 }
 
 
-static int getMaskNumBits(const unsigned char *mask, int len) {
-    int i = 0;
-    while (i < (len << 3)) {
-        if (!(mask[i>>3] & (0x80 >> (i & 3))))
-            break;
-        i++;
-    }
-    return i;
-}
-
 /*
  * supported arp opcode -- see 'ebtables -h arp' for the naming
  */
@@ -1227,21 +1201,8 @@ static bool
 virNWIPv4AddressParser(const char *input,
                        nwIPAddressPtr output)
 {
-    int i;
-    char *endptr;
-    const char *n = input;
-    long int d;
-
-    for (i = 0; i < 4; i++) {
-        d = strtol(n, &endptr, 10);
-        if (d < 0 || d > 255 ||
-            (endptr - n > 3 ) ||
-            (i <= 2 && *endptr != '.' ) ||
-            (i == 3 && *endptr != '\0'))
-            return 0;
-        output->addr.ipv4Addr[i] = (unsigned char)d;
-        n = endptr + 1;
-    }
+    if (virSocketParseIpv4Addr(input, &output->addr) == -1)
+        return 0;
 
     return 1;
 }
@@ -1250,81 +1211,8 @@ static bool
 virNWIPv6AddressParser(const char *input,
                        nwIPAddressPtr output)
 {
-    int i, j, pos;
-    uint16_t n;
-    int shiftpos = -1;
-    char prevchar;
-    char base;
-
-    memset(output, 0x0, sizeof(*output));
-
-    output->isIPv6 = 1;
-
-    pos = 0;
-    i = 0;
-
-    while (i < 8) {
-        j = 0;
-        n = 0;
-        while (1) {
-            prevchar = input[pos++];
-            if (prevchar == ':' || prevchar == 0) {
-                if (j > 0) {
-                    output->addr.ipv6Addr[i * 2 + 0] = n >> 8;
-                    output->addr.ipv6Addr[i * 2 + 1] = n;
-                    i++;
-                }
-                break;
-            }
-
-            if (j >= 4)
-                return 0;
-
-            if (prevchar >= '0' && prevchar <= '9')
-                base = '0';
-            else if (prevchar >= 'a' && prevchar <= 'f')
-                base = 'a' - 10;
-            else if (prevchar >= 'A' && prevchar <= 'F')
-                base = 'A' - 10;
-            else
-                return 0;
-            n <<= 4;
-            n |= (prevchar - base);
-            j++;
-        }
-
-        if (prevchar == 0)
-            break;
-
-        if (input[pos] == ':') {
-            pos ++;
-            // sequence of zeros
-            if (prevchar != ':')
-                return 0;
-
-            if (shiftpos != -1)
-                return 0;
-
-            shiftpos = i;
-        }
-    }
-
-    if (shiftpos != -1) {
-        if (i >= 7)
-            return 0;
-        i--;
-        j = 0;
-        while (i >= shiftpos) {
-            output->addr.ipv6Addr[15 - (j*2) - 1] =
-                output->addr.ipv6Addr[i * 2 + 0];
-            output->addr.ipv6Addr[15 - (j*2) - 0] =
-                output->addr.ipv6Addr[i * 2 + 1];
-            output->addr.ipv6Addr[i * 2 + 0] = 0;
-            output->addr.ipv6Addr[i * 2 + 1] = 0;
-            i--;
-            j++;
-        }
-    }
+    if (virSocketParseIpv6Addr(input, &output->addr) == -1)
+        return 0;
 
     return 1;
 }
@@ -1442,11 +1330,10 @@ virNWFilterRuleDetailsParse(virConnectPt
                         } else
                             rc = -1;
                     } else {
-                        if (checkIPv4Mask(datatype,
-                                          ipaddr.addr.ipv4Addr, nwf))
-                            *(uint8_t *)storage_ptr =
-                                  getMaskNumBits(ipaddr.addr.ipv4Addr,
-                                            sizeof(ipaddr.addr.ipv4Addr));
+                        int_val = virSocketGetNumNetmaskBits(
+                                                     &ipaddr.addr);
+                        if (int_val >= 0)
+                            *(uint8_t *)storage_ptr = int_val;
                         else
                             rc = -1;
                         found = 1;
@@ -1497,11 +1384,10 @@ virNWFilterRuleDetailsParse(virConnectPt
                         } else
                             rc = -1;
                     } else {
-                        if (checkIPv6Mask(datatype,
-                                          ipaddr.addr.ipv6Addr, nwf))
-                            *(uint8_t *)storage_ptr =
-                                  getMaskNumBits(ipaddr.addr.ipv6Addr,
-                                            sizeof(ipaddr.addr.ipv6Addr));
+                        int_val = virSocketGetNumNetmaskBits(
+                                                     &ipaddr.addr);
+                        if (int_val >= 0)
+                            *(uint8_t *)storage_ptr = int_val;
                         else
                             rc = -1;
                         found = 1;
@@ -2571,43 +2457,12 @@ virNWFilterPoolObjDeleteDef(virConnectPt
 static void
 virNWIPAddressFormat(virBufferPtr buf, nwIPAddressPtr ipaddr)
 {
-    if (!ipaddr->isIPv6) {
-        virBufferVSprintf(buf, "%d.%d.%d.%d",
-                          ipaddr->addr.ipv4Addr[0],
-                          ipaddr->addr.ipv4Addr[1],
-                          ipaddr->addr.ipv4Addr[2],
-                          ipaddr->addr.ipv4Addr[3]);
-    } else {
-        int i;
-        int dcshown = 0, in_dc = 0;
-        unsigned short n;
-        while (i < 8) {
-            n = (ipaddr->addr.ipv6Addr[i * 2 + 0] << 8) |
-                ipaddr->addr.ipv6Addr[i * 2 + 1];
-            if (n == 0) {
-                if (!dcshown) {
-                    in_dc = 1;
-                    if (i == 0)
-                        virBufferAddLit(buf, ":");
-                    dcshown = 1;
-                }
-                if (in_dc) {
-                    i++;
-                    continue;
-                }
-            }
-            if (in_dc) {
-                dcshown = 1;
-                virBufferAddLit(buf, ":");
-                in_dc = 0;
-            }
-            i++;
-            virBufferVSprintf(buf, "%x", n);
-            if (i < 8)
-                virBufferAddLit(buf, ":");
-        }
-        if (in_dc)
-            virBufferAddLit(buf, ":");
+    virSocketAddrPtr addr = &ipaddr->addr;
+    char *output = virSocketFormatAddr(addr);
+
+    if (output) {
+        virBufferVSprintf(buf, "%s", output);
+        VIR_FREE(output);
     }
 }
 
Index: libvirt-acl/src/nwfilter/nwfilter_ebiptables_driver.c
===================================================================
--- libvirt-acl.orig/src/nwfilter/nwfilter_ebiptables_driver.c
+++ libvirt-acl/src/nwfilter/nwfilter_ebiptables_driver.c
@@ -144,7 +144,7 @@ printDataType(virConnectPtr conn,
               nwItemDescPtr item)
 {
     int done;
-    int i, pos, s;
+    char *data;
 
     if (printVar(conn, vars, buf, bufsize, item, &done))
         return 1;
@@ -154,30 +154,38 @@ printDataType(virConnectPtr conn,
 
     switch (item->datatype) {
     case DATATYPE_IPADDR:
-        if (snprintf(buf, bufsize, "%d.%d.%d.%d",
-                     item->u.ipaddr.addr.ipv4Addr[0],
-                     item->u.ipaddr.addr.ipv4Addr[1],
-                     item->u.ipaddr.addr.ipv4Addr[2],
-                     item->u.ipaddr.addr.ipv4Addr[3]) >= bufsize) {
-            virNWFilterReportError(conn, VIR_ERR_INVALID_NWFILTER,
-                                   _("Buffer too small for IP address"));
+        data = virSocketFormatAddr(&item->u.ipaddr.addr);
+        if (!data) {
+            virNWFilterReportError(conn, VIR_ERR_INTERNAL_ERROR,
+                                   _("internal IPv4 address representation "
+                                     "is bad"));
+            return 1;
+        }
+        if (snprintf(buf, bufsize, "%s", data) >= bufsize) {
+            virNWFilterReportError(conn, VIR_ERR_INTERNAL_ERROR,
+                                   _("buffer too small for IP address"));
+            VIR_FREE(data);
             return 1;
         }
+        VIR_FREE(data);
     break;
 
     case DATATYPE_IPV6ADDR:
-        pos = 0;
-        for (i = 0; i < 16; i++) {
-            s = snprintf(&buf[pos], bufsize - pos, "%x%s",
-                         (unsigned int)item->u.ipaddr.addr.ipv6Addr[i],
-                         ((i & 1) && (i < 15)) ? ":" : "" );
-            if (s >= bufsize - pos) {
-                virNWFilterReportError(conn, VIR_ERR_INVALID_NWFILTER,
-                                       _("Buffer too small for IPv6 address"));
-                return 1;
-            }
-            pos += s;
+        data = virSocketFormatAddr(&item->u.ipaddr.addr);
+        if (!data) {
+            virNWFilterReportError(conn, VIR_ERR_INTERNAL_ERROR,
+                                   _("internal IPv6 address representation "
+                                     "is bad"));
+            return 1;
+        }
+
+        if (snprintf(buf, bufsize, "%s", data) >= bufsize) {
+            virNWFilterReportError(conn, VIR_ERR_INTERNAL_ERROR,
+                                   _("buffer too small for IPv6 address"));
+            VIR_FREE(data);
+            return 1;
         }
+        VIR_FREE(data);
     break;
 
     case DATATYPE_MACADDR:

-- 
libvir-list mailing list
libvir-list@xxxxxxxxxx
https://www.redhat.com/mailman/listinfo/libvir-list