
I have written a Python script that turns a gigantic text file into a SQLite3 database.

The text file is truly gigantic: it is 133 MiB (140,371,572 bytes) in size, and contains 140,371,572 characters and 5,558,701 lines.

The text file contains a list of all registered autonomous system numbers and their names, followed by a list of all public IPv4 networks with their ASN and country code, and finally the same kind of list for all public IPv6 networks.

I was instructed not to share the data without following the license requirements. I haven't read the requirements myself, but I got the data from here; you can download it if you want.

Sample of ASN list:

aut-num:                 AS1
name:                    LVLT-1

aut-num:                 AS2
name:                    UDEL-DCN

aut-num:                 AS3
name:                    MIT-GATEWAYS

aut-num:                 AS4
name:                    ISI-AS

Sample of IPv4 list:

net:                     1.0.0.0/8
country:                 AU

net:                     1.0.0.0/24
country:                 AU
aut-num:                 13335
is-anycast:              yes

net:                     1.0.1.0/24
country:                 CN

net:                     1.0.2.0/23
country:                 CN

Sample of IPv6 list:

net:                     2000::/12
aut-num:                 3356

net:                     2001::/32
aut-num:                 6939

net:                     2001:4:112::/48
aut-num:                 112

net:                     2001:200::/23
country:                 AU
aut-num:                 3356

Example IPv4 output:

[
    ["1.0.0.0/24", 16777216, 16777471, "1.0.0.0", "1.0.0.255", 256, 13335, "AU", false, true, false, false],
    ["1.0.1.0/23", 16777472, 16777983, "1.0.1.0", "1.0.2.255", 512, null, "CN", false, false, false, false],
    ["1.0.3.0/24", 16777984, 16778239, "1.0.3.0", "1.0.3.255", 256, null, "CN", false, false, false, false],
    ["1.0.4.0/22", 16778240, 16779263, "1.0.4.0", "1.0.7.255", 1024, 38803, "AU", false, false, false, false]
]

Example IPv6 output:

[
    ["2000::/16", 42535295865117307932921825928971026432, 42540488161975842760550356425300246527, "2000::", "2000:ffff:ffff:ffff:ffff:ffff:ffff:ffff", 5192296858534827628530496329220096, 3356, null, false, false, false, false],
    ["2001::/32", 42540488161975842760550356425300246528, 42540488241204005274814694018844196863, "2001::", "2001::ffff:ffff:ffff:ffff:ffff:ffff", 79228162514264337593543950336, 6939, null, false, false, false, false],
    ["2001:1::/31", 42540488241204005274814694018844196864, 42540488399660330303343369205932097535, "2001:1::", "2001:2:ffff:ffff:ffff:ffff:ffff:ffff", 158456325028528675187087900672, 3356, null, false, false, false, false]
]

My script does a lot of things. First, obviously, it reads the data; then it splits the data into individual entries; then it standardizes the length of the entries. As you can see above, the number of key-value pairs differs between entries, so it fills in the omitted values with defaults and sorts the keys so that the data can be put into a SQLite3 table.
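
To illustrate the standardization step, here is a minimal sketch (the key names mirror the DEFAULTS dict in the script; the real code also renames keys and converts 'yes' to True):

```python
# Minimal sketch of the standardization step: every entry ends up with
# the same set of keys, with omitted values filled in from defaults.
# (Key names mirror the DEFAULTS dict in the script below.)
defaults = {'network': None, 'ASN': None, 'country_code': None,
            'is_anycast': False}
raw = {'network': '1.0.1.0/24', 'country_code': 'CN'}
entry = {key: raw.get(key, default) for key, default in defaults.items()}
# entry == {'network': '1.0.1.0/24', 'ASN': None,
#           'country_code': 'CN', 'is_anycast': False}
```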

Then it parses the network slash notation into starting and ending IP addresses, their integer representations, and the number of addresses inside the network...
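
For comparison, the standard ipaddress module produces the same numbers for one of the example IPv4 rows above (my script implements this parsing by hand rather than using the module):

```python
import ipaddress

# The slash notation '1.0.4.0/22' from the example IPv4 output above.
net = ipaddress.ip_network('1.0.4.0/22')
print(int(net.network_address))    # 16778240 (starting integer)
print(int(net.broadcast_address))  # 16779263 (ending integer)
print(net.num_addresses)           # 1024 (addresses in the network)
```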

But I didn't stop there, because, as you can clearly see, the networks can overlap each other, and this happens very often. Adjacent networks (two networks are adjacent if the start of one immediately follows the end of the other) can also share the same data. It is a huge mess.

So I spent days trying to find ways to:

  • Merge the overlapping networks if they share the same data.

  • Split the overlapping networks into discrete networks if they don't share the same data; the data is always inherited from the smaller network.

  • Add network ranges not covered by the child networks; the added ranges inherit their data from the parent network.

  • Merge adjacent networks if they share the same data.

  • And finally, reorganize the whole thing so that the network slash notations are correct.
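
As a toy illustration of the splitting rule (split_around_child is a hypothetical helper written just for this post, not my actual code), a parent range is cut around one child range so that the child's data wins inside its span and the parent's data covers the remainder:

```python
# Hypothetical helper, for illustration only: split a parent range
# around one child range. Ranges are (start_int, end_int, data).
def split_around_child(parent, child):
    p_start, p_end, p_data = parent
    c_start, c_end, c_data = child
    pieces = []
    if p_start < c_start:
        pieces.append((p_start, c_start - 1, p_data))  # uncovered left part
    pieces.append((c_start, c_end, c_data))            # child's data wins
    if c_end < p_end:
        pieces.append((c_end + 1, p_end, p_data))      # uncovered right part
    return pieces

# 1.0.0.0/8 (AU) with child 1.0.0.0/24 (AU, AS13335), as in the sample:
print(split_around_child((16777216, 33554431, 'AU'),
                         (16777216, 16777471, 'AU, AS13335')))
# → [(16777216, 16777471, 'AU, AS13335'), (16777472, 33554431, 'AU')]
```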

I did everything I intended, to the absolute best of my ability.

However, I don't think my code is concise, elegant, or efficient. It took:

  • 0.691 seconds to load the text file

  • 30.856 seconds to split the data into entries

  • 147.904 seconds to analyze and reorganize the data

  • 10.575 seconds to write to SQLite3 database

  • 15.253 seconds to write to JSON files

  • 205.874 seconds to complete the whole thing.

The script:

import json
import re
import sqlite3
import time
from collections import deque
from enum import Enum
from functools import reduce
from operator import iconcat
from pathlib import Path
from typing import Deque, Iterable, List, Sequence, Tuple

script_start = time.time()

MAX_SQLITE_INT = 2 ** 63 - 1
MAX_IPV4 = 2**32-1
MAX_IPV6 = 2**128-1
IPV6DIGITS = set('0123456789abcdef:')
le255 = r'(25[0-5]|2[0-4]\d|[01]?\d\d?)'
IPV4_PATTERN = re.compile(rf'^{le255}\.{le255}\.{le255}\.{le255}$')
EMPTY = re.compile(r':?\b(?:0\b:?)+')
FIELDS = re.compile(r'::?|[\da-f]{1,4}')

JSON_TYPES = {
    True: 'true',
    False: 'false',
    None: 'null'
}

KEYS_RENAME = [
    ('net', 'network'),
    ('aut-num', 'ASN'),
    ('country', 'country_code'),
    ('is-anonymous-proxy', 'is_anonymous_proxy'),
    ('is-anycast', 'is_anycast'),
    ('is-satellite-provider', 'is_satellite_provider'),
    ('drop', 'bad')
]

CONTINENT_CODES = {
    'AF': 'Africa',
    'AN': 'Antarctica',
    'AS': 'Asia',
    'EU': 'Europe',
    'NA': 'North America',
    'OC': 'Oceania',
    'SA': 'South America'
}

DEFAULTS = {
    'network': None,
    'ASN': None,
    'country_code': None,
    'is_anonymous_proxy': False,
    'is_anycast': False,
    'is_satellite_provider': False,
    'bad': False
}

COLUMNS_SQL = """
create table if not exists {}(
    network text primary key,
    start_integer int not null,
    end_integer int not null,
    start_string text not null,
    end_string text not null,
    count int not null,
    ASN int, country_code text,
    is_anonymous_proxy int default 0,
    is_anycast int default 0,
    is_satellite_provider int default 0,
    bad int default 0
);
"""

sqlite3.register_adapter(
    int, lambda x: hex(x) if x > MAX_SQLITE_INT else x)
sqlite3.register_converter(
    'integer', lambda b: int(b, 16 if b[:2] == b'0x' else 10))
conn = sqlite3.connect('D:/network_guard/IPFire_locations.db',
                       detect_types=sqlite3.PARSE_DECLTYPES)
cur = conn.cursor()

COUNTRY_CODES = Path(
    'D:/network_guard/countries.txt').read_text(encoding='utf8').splitlines()
COUNTRY_CODES = [i.split('\t') for i in COUNTRY_CODES]
COUNTRY_CODES = [(a, c, b, CONTINENT_CODES[b]) for a, b, c in COUNTRY_CODES]

data_loading_start = time.time()
DATA = Path(
    'D:/network_guard/database.txt').read_text(encoding='utf8').splitlines()
DATA = (i for i in DATA if not i.startswith('#'))
print(time.time() - data_loading_start)
cur.execute("""
create table if not exists Country_Codes(
    country_code text primary key,
    country_name text not null,
    continent_code text not null,
    continent_name text not null
)
""")
cur.execute("""
create table if not exists Autonomous_Systems(
    ASN text primary key,
    entity text not null
)
""")
cur.execute(COLUMNS_SQL.format('IPv4'))
cur.execute(COLUMNS_SQL.format('IPv6'))
cur.executemany(
    "insert into Country_Codes values (?, ?, ?, ?);", COUNTRY_CODES)
conn.commit()
KEY_ORDER = [
    'ASN',
    'country_code',
    'is_anonymous_proxy',
    'is_anycast',
    'is_satellite_provider',
    'bad'
]


def parse_ipv4(ip: str) -> int:
    assert (match := IPV4_PATTERN.match(ip))
    a, b, c, d = match.groups()
    return (int(a) << 24) + (int(b) << 16) + (int(c) << 8) + int(d)


def to_ipv4(n: int) -> str:
    assert 0 <= n <= MAX_IPV4
    return f'{n >> 24 & 255}.{n >> 16 & 255}.{n >> 8 & 255}.{n & 255}'


def preprocess_ipv6(ip: str) -> Tuple[list, bool, int]:
    assert 2 < len(ip) <= 39 and not set(ip) - IPV6DIGITS
    compressed = False
    terminals = False
    segments = ip.lower().split(':')
    if not segments[0]:
        assert not segments[1]
        terminals = 1
        compressed = True
        segments = segments[2:]
    if not segments[-1]:
        assert not compressed and not segments[-2]
        terminals = 2
        segments = segments[:-2]

    return segments, compressed, terminals


def split_ipv6(ip: str) -> Deque[str]:
    segments, compressed, terminals = preprocess_ipv6(ip)
    chunks = deque()
    if terminals == 1:
        chunks.append('::')
    assert len(segments) <= 8 - (compressed or bool(terminals))
    for seg in segments:
        if not seg:
            assert not compressed
            chunks.append('::')
            compressed = True
        else:
            assert len(seg) <= 4
            chunks.append(seg)
    if terminals == 2:
        chunks.append('::')
    return chunks


def parse_ipv6(ip: str) -> int:
    if ip == '::':
        return 0
    segments = split_ipv6(ip)
    pos = 7
    n = 0
    for i, seg in enumerate(segments):
        if seg == '::':
            pos = len(segments) - i - 2
        else:
            n += int(seg, 16) << pos*16
            pos -= 1
    return n


def to_ipv6(n: int, compress: bool = True) -> str:
    assert 0 <= n <= MAX_IPV6
    ip = '{:039_x}'.format(n).replace('_', ':')
    if compress:
        ip = ':'.join(
            s.lstrip('0')
            if s != '0000'
            else '0'
            for s in ip.split(':')
        )
        longest = max(EMPTY.findall(ip), key=len, default='')
        if len(longest) > 2:
            ip = ip.replace(longest, '::', 1)
    return ip


class IPv4(Enum):
    parser = parse_ipv4
    formatter = to_ipv4
    power = 32
    maximum = MAX_IPV4


class IPv6(Enum):
    parser = parse_ipv6
    formatter = to_ipv6
    power = 128
    maximum = MAX_IPV6


class IP_Address:
    def __str__(self) -> str:
        return self.string

    def __int__(self) -> int:
        return self.integer

    def __repr__(self) -> str:
        return f"IP_Address(integer={self.integer}, string='{self.string}', version={self.version.__name__})"

    def __init__(self, value: int | str, version: IPv4 | IPv6 = None) -> None:
        if isinstance(value, str):
            self.from_string(value, version)
        else:
            self.from_integer(value, version)

        self.maxpower = 32 if self.version == IPv4 else 128
        self.hexadecimal = f'{self.integer:08x}' if self.version == IPv4 else f'{self.integer:032x}'
        self.pointer = None

    def from_string(self, value: str, version: IPv4 | IPv6) -> None:
        if not version:
            version = IPv4 if IPV4_PATTERN.match(value) else IPv6
        self.integer = version.parser(value)
        self.string = value
        self.version = version

    def from_integer(self, value: int, version: IPv4 | IPv6) -> None:
        assert isinstance(value, int)
        if not version:
            version = IPv4 if 0 <= value <= MAX_IPV4 else IPv6
        self.string = version.formatter(value)
        self.integer = value
        self.version = version

    def get_pointer(self) -> str:
        if not self.pointer:
            if self.version == IPv4:
                self.pointer = '.'.join(
                    self.string.split('.')[::-1]
                ) + '.in-addr.arpa'
            else:
                self.pointer = '.'.join(self.hexadecimal[::-1])+'.ip6.arpa'
        return self.pointer


def parse_network(network: str) -> Tuple[int, int]:
    start, slash = network.split('/')
    start = IP_Address(start)
    slash = int(slash)
    count = 2 ** (start.maxpower - slash)
    end = IP_Address(start.integer + count - 1)
    return start.integer, end.integer


def parse_entry(entry: Deque[tuple]) -> list:
    entry = dict(entry)
    entry = {
        new_key: v
        if (v := entry.get(
            key, DEFAULTS[new_key]
        )
        ) != 'yes'
        else True
        for key, new_key
        in KEYS_RENAME
    }
    assert (network := entry['network'])
    asn = entry['ASN']
    if asn != None:
        entry['ASN'] = int(asn)

    return [*parse_network(network), *(entry[k] for k in KEY_ORDER)]


data_processing_start = time.time()
ASN = deque()
IPV4_TABLE = deque()
IPV6_TABLE = deque()
entry = deque()
for line in DATA:
    if not line:
        if entry:
            if entry[0][0] == 'aut-num':
                ASN.append((entry[0][1], entry[1][1]))
            else:
                assert entry[0][0] == 'net'
                ip = entry[0][1].split('/')[0]
                table = IPV4_TABLE if IPV4_PATTERN.match(ip) else IPV6_TABLE
                table.append(parse_entry(entry))
        entry = deque()
    else:
        key, val = line.split(':', 1)
        entry.append((key, val.strip()))
print(time.time() - data_processing_start)


class TreeDict(dict):
    def __init__(self, *args, **kwargs) -> None:
        super(TreeDict, self).__init__(*args, **kwargs)

    def add(self, k: Sequence) -> None:
        child = self.setdefault(k[0], TreeDict())
        if len(k) > 1:
            child.add(k[1:])

    def trim_end(self) -> dict:
        result = {
            'nested': {},
            'flat': []
        }
        for k, v in self.items():
            if v:
                result['nested'][k] = v.trim_end()
            else:
                result['flat'].append(k)
        if not result['flat']:
            result.pop('flat')
        if not result['nested']:
            result.pop('nested')
        return result

    def pretty(self) -> str:
        return json.dumps(self.trim_end(), indent=4)


def split_ranges(chunks: deque, stack: deque, parent: tuple) -> Deque[list]:
    processed, parent_end, *parent_data = parent
    for start, end, *_ in stack:
        if processed < start:
            chunks.append([processed, start-1, *parent_data])
        processed = end + 1
    if processed < parent_end:
        chunks.append([processed, parent_end, *parent_data])
    return chunks


def flatten_tree(tree: dict) -> List[list]:
    flattened = []
    if (chunks := tree.get('chunks')):
        flattened.extend(chunks)
    if (subtree := tree.get('subtree')):
        flattened = reduce(iconcat, map(
            flatten_tree, subtree.values()), flattened)
    return flattened


def merge_data(data: List[list]) -> Deque[list]:
    chunks = deque()
    last = [False]*8
    for row in data:
        if row[2:] != last[2:] or row[0] > last[1] + 1:
            chunks.append(row)
            last = row
        else:
            last[1] = row[1]
    return chunks


def to_terminal(node: Tuple[int, int, bool]) -> Tuple[Tuple[int, int, bool], Tuple[int, int, bool]]:
    start, end, i = node
    return (start, i, False), (end, i, True)


class Table_Analyzer:
    def __init__(self, table: Deque[list]) -> None:
        indices = ((start, end, i) for i, (start, end, *_) in enumerate(table))
        nodes = reduce(iconcat, map(to_terminal, indices), [])
        nodes.sort()
        self.nodes = nodes
        self.table = table
        self.tree = None
        self.analyzed_tree = None
        self.result = None

    def get_tree(self) -> None:
        tree = TreeDict()
        stack = []
        last = 0
        for number, i, end in self.nodes:
            if end:
                if number != last:
                    tree.add(stack)
                last = number
                if i in stack:
                    stack.remove(i)
            elif not stack or self.table[i][2:] != self.table[stack[-1]][2:]:
                stack.append(i)

        self.tree = tree

    def analyze_tree(self, tree: TreeDict, parent: tuple = None) -> dict:
        subtree = {}
        chunks = deque()
        stack = deque()
        for k, v in tree.items():
            item = self.table[k]
            stack.append(item)
            if not v:
                chunks.append(item)
            else:
                node = tuple(item)
                subtree[node] = self.analyze_tree(v, node)
        if parent:
            chunks = split_ranges(chunks, stack, parent)
        return {'subtree': subtree, 'chunks': chunks}

    def analyze(self) -> Deque[list]:
        if not self.result:
            if not self.tree:
                self.get_tree()
            if not self.analyzed_tree:
                self.analyzed_tree = self.analyze_tree(self.tree)
            self.result = merge_data(sorted(flatten_tree(self.analyzed_tree)))
        return self.result


def binary_decompose(n: int) -> List[int]:
    return [1 << i for i, b in enumerate(bin(n).removeprefix('0b')[::-1]) if b == '1']


def powers_of_2(n: int) -> Deque[int]:
    powers = deque()
    i = 0
    total = n.bit_count()
    for p in range(n.bit_length(), -1, -1):
        if (power := 1 << p) & n:
            powers.append(power)
            if (i := i + 1) == total:
                break
    return powers


def format_network(start_ip: str, count: int, version: IPv4 | IPv6) -> str:
    slash = version.power.value - count.bit_length() + 1
    return f'{start_ip}/{slash}'


def get_network(start: int, count: int, version: IPv4 | IPv6, data: Sequence) -> tuple:
    start_ip = version.formatter(start)
    end = start + count - 1
    end_ip = version.formatter(end)
    network = format_network(start_ip, count, version)
    return (network, start, end, start_ip, end_ip, count, *data)


def to_network(item: Sequence, version: IPv4 | IPv6 = IPv4) -> Deque[tuple]:
    start, end, *data = item
    count = end - start + 1
    subnets = deque()
    powers = powers_of_2(count)
    for power in powers:
        subnets.append(get_network(start, power, version, data))
        start += power
    return subnets


data_analyzing_start = time.time()
ipv4_analyzer = Table_Analyzer(IPV4_TABLE)
IPV4_TABLE = reduce(iconcat, map(to_network, ipv4_analyzer.analyze()), [])
ipv6_analyzer = Table_Analyzer(IPV6_TABLE)
IPV6_TABLE = reduce(iconcat, map(
    lambda x: to_network(x, IPv6), ipv6_analyzer.analyze()), [])
print(time.time() - data_analyzing_start)

data_saving_start = time.time()
cur.executemany("insert into Autonomous_Systems values (?, ?);", ASN)
cur.executemany(
    f"insert into IPv4 values ({', '.join(['?']*12)});", IPV4_TABLE)
cur.executemany(
    f"insert into IPv6 values ({', '.join(['?']*12)});", IPV6_TABLE)
conn.commit()
print(time.time() - data_saving_start)


def json_repr(row: Sequence) -> str:
    items = deque()
    for e in row:
        if isinstance(e, (int, float)) and not isinstance(e, bool):
            items.append(f'{e}')
        elif isinstance(e, str):
            item = e.replace('"', '\\"')
            items.append(f'"{item}"')
        else:
            items.append(JSON_TYPES[e])
    return '['+', '.join(items)+']'


def pretty_table(table: Iterable[Sequence]) -> str:
    return '[\n'+'\n'.join(f'\t{json_repr(row)},' for row in table)+'\n]'


data_serializing_start = time.time()
Path('D:/network_guard/IPv4_table.json').write_text(pretty_table(IPV4_TABLE))
Path('D:/network_guard/IPv6_table.json').write_text(pretty_table(IPV6_TABLE))
Path('D:/network_guard/Autonomous_Systems.json').write_text(pretty_table(ASN), encoding='utf8')
Path('D:/network_guard/Country_Codes.json').write_text(pretty_table(COUNTRY_CODES))
print(time.time() - data_serializing_start)
print(time.time() - script_start)

How can I make it more efficient, concise, and elegant?
