commit 35e8d094eb9dc58ab1e0a0504fd4ccb6bd68cc9a Author: janik Date: Mon Mar 21 14:00:49 2022 +0100 init commit diff --git a/imggen.py b/imggen.py new file mode 100644 index 0000000..d23e64a --- /dev/null +++ b/imggen.py @@ -0,0 +1,70 @@ +from PIL import Image +import numpy +import numba + +@numba.jit(nopython=True) +def _hilbertMap(n, size): + # https://en.wikipedia.org/wiki/Hilbert_curve + def rot(n, x, y, rx, ry): + if ry == 0: + if rx == 1: + x = n-1 - x + y = n-1 - y + + x, y = y, x + + return (x, y) + + x, y = (0, 0) + t = n + s = 1 + while (s < size): + rx = 1 & int(t / 2) + ry = 1 & (t ^ rx) + x, y = rot(s, x, y, rx, ry) + x += s * rx + y += s * ry + t = int(t / 4) + s *= 2 + + return x, y + +class imggen: + x, y = (0,0) + imgraw = None + + def __init__(self, size): + if size % 2 != 0: + raise ValueError("Size needs to be divisible by 2") + + self.x = int(pow(2, size/2)) + self.y = self.x + + self.imgraw = numpy.zeros((self.x, self.y, 3), dtype=numpy.uint8) + + + def colorRange(self, start, end, color): + for p in range(start, end+1): + x, y = _hilbertMap(p, self.x) + self.imgraw[x, y] = color + + def saveImage(self, path, scale=1): + img = Image.fromarray(self.imgraw, mode="RGB") + img = img.resize([self.x * scale, self.y * scale], Image.NEAREST) + img.save(path) + + def showImage(self, scale=1): + img = Image.fromarray(self.imgraw, mode="RGB") + img = img.resize([self.x * scale, self.y * scale], Image.NEAREST) + img.show() + + def clear(self, color=(0,0,0)): + self.imgraw = numpy.full((self.x, self.y, 3), color, dtype=numpy.uint8) + + + +if __name__ == "__main__": + gen = imggen(24) + gen.clear((255, 255, 255)) + gen.colorRange(0, pow(2, 22), (73, 255, 51)) + gen.showImage() \ No newline at end of file diff --git a/ip.py b/ip.py new file mode 100644 index 0000000..63dc82a --- /dev/null +++ b/ip.py @@ -0,0 +1,201 @@ +from ast import Str +import re + +class net: + type = 0 + start = None + end = None + cidr = 0 + + def __init__(self): + pass + + def parseNet(self, netstr: str): + netstr = netstr.strip() + # basic check + if re.fullmatch("^[\.:0-9a-f]*\/[0-9]*$", netstr) is None: + raise ValueError("Net {} is malformed".format(netstr)) + + adr, cidr = netstr.split("/") + + start = ip(adr) + start.numerical = start._mask(int(cidr)) + end = ip(adr) + end.numerical = end._mask(int(cidr), True) + + self.type = start.type + self.start = start + self.end = end + self.cidr = int(cidr) + + def getSize(self): + if self.type == 0: + return 32-self.cidr + return 128-self.cidr + +def _genMask(n: int): + x = 0 + for i in range(n): + x = (x << 1) + 1 + return ~x + +class ip: + type = 0 + numerical = 0 + + def __init__(self): + pass + + def __init__(self, octets, type): + self.fromOctets(octets, type) + + def __init__(self, adr: Str): + self.parseAdrString(adr) + + def toStr(self, mask=0): + if self.type == 0: + o = [str(self.getOctet(i, mask)) for i in range(4)] + return ".".join(o) + else: + raise NotImplementedError() + + def getOctet(self, octet, mask=0): + h = 4 if self.type == 0 else 8 + o = 8 if self.type == 0 else 16 + n = self._mask(mask) + return (n >> (o * (h - octet - 1))) & ~_genMask(o) + + def fromOctets(self, octets, type): + if type == 0: + c = 4 + o = 8 + else: + c = 8 + o = 16 + + self.type = type + self.numerical = 0 + for i in range(c): + self.numerical = (self.numerical << o) + octets[i] + + def fromNumerical(self, num: int, type): + self.num = num + self.type = type + + def _mask(self, n: int, sethigh: bool=False): + h = 32 if self.type == 0 else 128 + if sethigh: + return self.numerical | ~_genMask(h-n) + else: + return self.numerical & _genMask(h-n) + + def _splitMergedOctet(self, octet: str, splits: int): + n = int(octet) + r = [0]*splits + + if n > 1 << (splits*8): + raise ValueError("Invalid Octet {}".format(octet)) + + for i in range(splits-1, -1, -1): + r[i] = n % 256 + n = n >> 8 + + return r + + def toInt(self, trunc_bits = 0): + return self.numerical >> trunc_bits + + def _parseV4(self, adr: str): + # basic validation + if re.fullmatch("^[\.0-9]*$", adr) is None: + raise ValueError("Address {} contains illegal symbols".format(adr)) + + parts = adr.split(".") + + r = [] + if len(parts) < 4: + for i in range(len(parts)-1): + r += [int(parts[i])] + r += self._splitMergedOctet(parts[-1], 4 - len(parts) + 1) + elif len(parts) == 4: + for i in range(len(parts)): + r += [int(parts[i])] + else: + raise ValueError("Invalid IPv4 {}, to many octets".format(adr)) + + # validate each octet + for p in r: + n = int(p) + if (n < 0) or (n > 255): + raise ValueError("Invalid octet {} in address {}".format(p, adr)) + + self.fromOctets(r, 0) + + def _parseV6(self, adr: str): + # ignore case + adr = adr.lower() + # basic validation + if re.fullmatch("^[:0-9a-f]*$", adr) is None: + raise ValueError("Address {} contains illegal symbols".format(adr)) + + # handle null case, because it breaks the rest + if adr == "::": + self.numerical = 0 + self.type = 1 + return + + # handle leading and trailing ellipses + # error out, when there is a ":" without an ellipsis + if adr[0] == ":": + adr = adr[1:] + if adr[0] != ":": + raise ValueError("Address :{} starts with : that is not part of an ellipsis".format(adr)) + + if adr[-1] == ":": + adr = adr[:-1] + if adr[-1] != ":": + raise ValueError("Address {}: ends with : that is not part of an ellipsis".format(adr)) + + parts = adr.split(":") + + # validate each group + empties = 0 + for p in parts: + if p == "": + empties += 1 + if empties > 1: + raise ValueError("Address can contain only one ellipsis! Illegal address {}".format(adr)) + else: + n = int(p, 16) + if (n < 0) or (n > 65535): + raise ValueError("Invalid octet {} in address {}".format(p, adr)) + + if len(parts) < 8 and empties == 0: + raise ValueError("Invalid address {}: not enough octets without ellipsis".format(p, adr)) + + if len(parts) > 8: + raise ValueError("Invalid IPv6 {}, to many octets".format(adr)) + + r = [0]*8 + + # assign octets until we reach :: or the end of the address + for i in range(len(parts)): + if parts[i] == "": + break + else: + r[i] = int(parts[i], 16) + + # assign remaining octets backwards (filling :: with 0s) + if i < 7: + i = 7 + while (parts[i - 8] != ""): + r[i] = int(parts[i - 8], 16) + i -= 1 + + self.fromOctets(r, 1) + + def parseAdrString(self, adr: str): + if adr.count(":") == 0: + self._parseV4(adr) + else: + self._parseV6(adr) diff --git a/main.py b/main.py new file mode 100644 index 0000000..d42c15e --- /dev/null +++ b/main.py @@ -0,0 +1,86 @@ +import argparse +from ctypes.wintypes import INT + +parser = argparse.ArgumentParser() +parser.add_argument('-f', help='Writes visualization PNG to this path. Defaults to \"nets.png\".', default='nets.png', metavar='PATH') +parser.add_argument('-s', help='Smallest visible net, defaults to /32 for IPv4 and /64 for IPv6', default=-1, metavar='CIDR', type=int) +parser.add_argument( + '-n', '--networks', + required=True, + help='File containing network definitions.') + +args = parser.parse_args() + +import re +import ip +import imggen + +def parseLine(line: str): + def splitC(c): + r = (c & 0xFF0000) >> 16 + g = (c & 0x00FF00) >> 8 + b = (c & 0x0000FF) >> 0 + return (r, g, b) + + if re.fullmatch("^color\ *=.*$", line) is not None: + cstr = line.split("=")[1].strip().lstrip("#x") + cv = int(cstr, 16) + return ("col", splitC(cv)) + + if re.fullmatch("^fill\ *=.*$", line) is not None: + cstr = line.split("=")[1].strip().lstrip("#x") + cv = int(cstr, 16) + return ("fil", splitC(cv)) + + if len(line) == 0 or line[0] == "#": + return ("com", line) + + if re.fullmatch("^[\.:0-9a-f]*\/[0-9]*$", line) is not None: + n = ip.net() + n.parseNet(line) + return("net", n) + + return ("nal", 0) + +with open(args.networks) as f: + snet = ip.net() + snet.parseNet(f.readline()) + + smallest = args.s + if smallest == -1: + if snet.type == 0: + smallest = 32 + else: + smallest = 64 + + if snet.type == 0: + smallest = 32 - smallest + else: + smallest = 128 - smallest + + if (snet.getSize() - smallest) % 2 != 0: + raise argparse.ArgumentError(message="The difference between the base net and the smallest net-size must be even!") + + gen = imggen.imggen(snet.getSize() - smallest) + + c = (255, 255, 255) + + for l in f.readlines(): + pl = parseLine(l.strip()) + if pl[0] == "col": + c = pl[1] + + if pl[0] == "fil": + gen.clear(pl[1]) + + if pl[0] == "net": + s = pl[1].start.toInt(smallest) - snet.start.toInt(smallest) + e = pl[1].end.toInt(smallest) - snet.start.toInt(smallest) + gen.colorRange(s, e, c) + + s = int(1024 / gen.x) + if (s < 1): + s = 1 + + gen.saveImage(args.f, s) + gen.showImage(s) \ No newline at end of file diff --git a/nets.png b/nets.png new file mode 100644 index 0000000..efb426d Binary files /dev/null and b/nets.png differ diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..6c1bdaf --- /dev/null +++ b/tests.py @@ -0,0 +1,141 @@ +import unittest +import ip +import random + +class TestIPParsing(unittest.TestCase): + + def gen_ipv4(self): + a, b, c, d = [random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)] + match random.randint(0, 20): + case 0: # merge 2 octets + adr_str = "{}.{}.{}".format(a, b, (c << 8) + d) + case 1: # merge 3 octets + adr_str = "{}.{}".format(a, (b << 16) + (c << 8) + d) + case 2: # fully degenerated address + adr_str = "{}".format((a << 24) + (b << 16) + (c << 8) + d) + case _: + adr_str = "{}.{}.{}.{}".format(a, b, c, d) + return adr_str, [a, b, c, d] + + def gen_ipv6(self): + o = [] + for j in range(random.randint(4, 8)): # 4 to 8 octets with data + o += [random.randint(0, 65535)] + + split = 0 + if len(o) < 8: # fill remaining octets with 0 + split = random.randint(0, len(o)) + + # prepare string with ellipsis + h = [hex(v)[2:] for v in o[:split]] + h += [""] + h += [hex(v)[2:] for v in o[split:]] + + # if the ellipsis is at the start (or end) of an address, we need an extra ":" + if split == 0: + h = [""] + h + if split == len(o): + h += [""] + + o = o[:split] + [0] * (8 - len(o)) + o[split:] + else: # or without + h = [hex(v)[2:] for v in o] + + + # create address string from hex octets + adr_str = ":".join(h) + + return adr_str, o + + def test_ipv4(self): + ip_obj = ip.ip() + + print("\nIPv4 random cases") + + for i in range(100000): + adr_str, octs = self.gen_ipv4() + ip_obj._parseV4(adr_str) + self.assertEqual(ip_obj.octets, octs) + + print("IPv4 failure cases") + + with self.assertRaises(Exception): + ip_obj._parseV4("192.168.2.260") + + with self.assertRaises(Exception): + ip_obj._parseV4("192.168.2.20.23") + + with self.assertRaises(Exception): + ip_obj._parseV4("192.168.2.-20") + + with self.assertRaises(Exception): + ip_obj._parseV4("192.168.ab.0") + + print("IPv4 ok") + + def test_ipv6(self): + ip_obj = ip.ip() + + print("\nIPv6 random cases") + + for i in range(10000): + adr_str, octs = self.gen_ipv6() + ip_obj._parseV6(adr_str) + self.assertEqual(ip_obj.octets, octs) + + print("IPv6 special cases") + + ip_obj._parseV6("::") + self.assertEqual(ip_obj.octets, [0]*8) + + ip_obj._parseV6("1::") + self.assertEqual(ip_obj.octets, [1]+[0]*7) + + ip_obj._parseV6("::1") + self.assertEqual(ip_obj.octets, [0]*7+[1]) + + print("IPv6 failure cases") + + with self.assertRaises(Exception): + ip_obj._parseV6("192.168.2.20") + + with self.assertRaises(Exception): + ip_obj._parseV6("ff80:12::23::") + + with self.assertRaises(Exception): + ip_obj._parseV6("1:2:3:4:5:6:7:8:9") + + with self.assertRaises(Exception): + ip_obj._parseV6("ff::1ffff") + + with self.assertRaises(Exception): + ip_obj._parseV6("1:2:3:4:5:6:7") + + with self.assertRaises(Exception): + ip_obj._parseV6("1:2:3:4:5:6:7:") + + with self.assertRaises(Exception): + ip_obj._parseV6("3159:df6c:d506:60c3:9640:") + + print("IPv6 ok") + + def test_mixed(self): + ip_obj = ip.ip() + print("\nmixed random parsing") + + for i in range(10000): + if random.random() < 0.5: + adr_str, octs = self.gen_ipv4() + type = 0 + else: + adr_str, octs = self.gen_ipv6() + type = 1 + + ip_obj.parseAdrString(adr_str) + self.assertEqual(ip_obj.octets, octs) + self.assertEqual(ip_obj.type, type) + + print("mixed ok") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file