historical/moneyrocket.git/contrib/seeds/asmap.py

816 lines
36 KiB
Python
Raw Normal View History

2024-01-16 17:20:27 +00:00
# Copyright (c) 2022 Pieter Wuille
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
"""
This module provides the ASNEntry and ASMap classes.
"""
import copy
import ipaddress
import random
import unittest
from enum import Enum
from functools import total_ordering
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, overload
def net_to_prefix(net: Union[ipaddress.IPv4Network,ipaddress.IPv6Network]) -> List[bool]:
"""
Convert an IPv4 or IPv6 network to a prefix represented as a list of bits.
IPv4 ranges are remapped to their IPv4-mapped IPv6 range (::ffff:0:0/96).
"""
num_bits = net.prefixlen
netrange = int.from_bytes(net.network_address.packed, 'big')
# Map an IPv4 prefix into IPv6 space.
if isinstance(net, ipaddress.IPv4Network):
num_bits += 96
netrange += 0xffff00000000
# Strip unused bottom bits.
assert (netrange & ((1 << (128 - num_bits)) - 1)) == 0
return [((netrange >> (127 - i)) & 1) != 0 for i in range(num_bits)]
def prefix_to_net(prefix: List[bool]) -> Union[ipaddress.IPv4Network,ipaddress.IPv6Network]:
"""The reverse operation of net_to_prefix."""
# Convert to number
netrange = sum(b << (127 - i) for i, b in enumerate(prefix))
num_bits = len(prefix)
assert num_bits <= 128
# Return IPv4 range if in ::ffff:0:0/96
if num_bits >= 96 and (netrange >> 32) == 0xffff:
return ipaddress.IPv4Network((netrange & 0xffffffff, num_bits - 96), True)
# Return IPv6 range otherwise.
return ipaddress.IPv6Network((netrange, num_bits), True)
# Shortcut for (prefix, ASN) entries.
ASNEntry = Tuple[List[bool], int]
# Shortcut for (prefix, old ASN, new ASN) entries.
ASNDiff = Tuple[List[bool], int, int]
class _VarLenCoder:
"""
A class representing a custom variable-length binary encoder/decoder for
integers. Each object represents a different coder, with different parameters
minval and clsbits.
The encoding is easiest to describe using an example. Let's say minval=100 and
clsbits=[4,2,2,3]. In that case:
- x in [100..115]: encoded as [0] + [4-bit BE encoding of (x-100)].
- x in [116..119]: encoded as [1,0] + [2-bit BE encoding of (x-116)].
- x in [120..123]: encoded as [1,1,0] + [2-bit BE encoding of (x-120)].
- x in [124..131]: encoded as [1,1,1] + [3-bit BE encoding of (x-124)].
In general, every number is encoded as:
- First, k "1"-bits, where k is the class the number falls in (there is one class
per element of clsbits).
- Then, a "0"-bit, unless k is the highest class, in which case there is nothing.
- Lastly, clsbits[k] bits encoding in big endian the position in its class that
number falls into.
- Every class k consists of 2^clsbits[k] consecutive integers. k=0 starts at minval,
other classes start one past the last element of the class before it.
"""
def __init__(self, minval: int, clsbits: List[int]):
"""Construct a new _VarLenCoder."""
self._minval = minval
self._clsbits = clsbits
self._maxval = minval + sum(1 << b for b in clsbits) - 1
def can_encode(self, val: int) -> bool:
"""Check whether value val is in the range this coder supports."""
return self._minval <= val <= self._maxval
def encode(self, val: int, ret: List[int]) -> None:
"""Append encoding of val onto integer list ret."""
assert self._minval <= val <= self._maxval
val -= self._minval
bits = 0
for k, bits in enumerate(self._clsbits):
if val >> bits:
# If the value will not fit in class k, subtract its range from v,
# emit a "1" bit and continue with the next class.
val -= 1 << bits
ret.append(1)
else:
if k + 1 < len(self._clsbits):
# Unless we're in the last class, emit a "0" bit.
ret.append(0)
break
# And then encode v (now the position within the class) in big endian.
ret.extend((val >> (bits - 1 - b)) & 1 for b in range(bits))
def encode_size(self, val: int) -> int:
"""Compute how many bits are needed to encode val."""
assert self._minval <= val <= self._maxval
val -= self._minval
ret = 0
bits = 0
for k, bits in enumerate(self._clsbits):
if val >> bits:
val -= 1 << bits
ret += 1
else:
ret += k + 1 < len(self._clsbits)
break
return ret + bits
def decode(self, stream, bitpos) -> Tuple[int,int]:
"""Decode a number starting at bitpos in stream, returning value and new bitpos."""
val = self._minval
bits = 0
for k, bits in enumerate(self._clsbits):
bit = 0
if k + 1 < len(self._clsbits):
bit = stream[bitpos]
bitpos += 1
if not bit:
break
val += 1 << bits
for i in range(bits):
bit = stream[bitpos]
bitpos += 1
val += bit << (bits - 1 - i)
return val, bitpos
# Variable-length encoders used in the binary asmap format.
_CODER_INS = _VarLenCoder(0, [0, 0, 1])
_CODER_ASN = _VarLenCoder(1, list(range(15, 25)))
_CODER_MATCH = _VarLenCoder(2, list(range(1, 9)))
_CODER_JUMP = _VarLenCoder(17, list(range(5, 31)))
class _Instruction(Enum):
"""One instruction in the binary asmap format."""
# A return instruction, encoded as [0], returns a constant ASN. It is followed by
# an integer using the ASN encoding.
RETURN = 0
# A jump instruction, encoded as [1,0] inspects the next unused bit in the input
# and either continues execution (if 0), or skips a specified number of bits (if 1).
# It is followed by an integer, and then two subprograms. The integer uses jump encoding
# and corresponds to the length of the first subprogram (so it can be skipped).
JUMP = 1
# A match instruction, encoded as [1,1,0] inspects 1 or more of the next unused bits
# in the input with its argument. If they all match, execution continues. If they do
# not, failure is returned. If a default instruction has been executed before, instead
# of failure the default instruction's argument is returned. It is followed by an
# integer in match encoding, and a subprogram. That value is at least 2 bits and at
# most 9 bits. An n-bit value signifies matching (n-1) bits in the input with the lower
# (n-1) bits in the match value.
MATCH = 2
# A default instruction, encoded as [1,1,1] sets the default variable to its argument,
# and continues execution. It is followed by an integer in ASN encoding, and a subprogram.
DEFAULT = 3
# Not an actual instruction, but a way to encode the empty program that fails. In the
# encoder, it is used more generally to represent the failure case inside MATCH instructions,
# which may (if used inside the context of a DEFAULT instruction) actually correspond to
# a successful return. In this usage, they're always converted to an actual MATCH or RETURN
# before the top level is reached (see make_default below).
END = 4
class _BinNode:
"""A class representing a (node of) the parsed binary asmap format."""
@overload
def __init__(self, ins: _Instruction): ...
@overload
def __init__(self, ins: _Instruction, arg1: int): ...
@overload
def __init__(self, ins: _Instruction, arg1: "_BinNode", arg2: "_BinNode"): ...
@overload
def __init__(self, ins: _Instruction, arg1: int, arg2: "_BinNode"): ...
def __init__(self, ins: _Instruction, arg1=None, arg2=None):
"""
Construct a new asmap node. Possibilities are:
- _BinNode(_Instruction.RETURN, asn)
- _BinNode(_Instruction.JUMP, node_0, node_1)
- _BinNode(_Instruction.MATCH, val, node)
- _BinNode(_Instruction.DEFAULT, asn, node)
- _BinNode(_Instruction.END)
"""
self.ins = ins
self.arg1 = arg1
self.arg2 = arg2
if ins == _Instruction.RETURN:
assert isinstance(arg1, int)
assert arg2 is None
self.size = _CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1)
elif ins == _Instruction.JUMP:
assert isinstance(arg1, _BinNode)
assert isinstance(arg2, _BinNode)
self.size = (_CODER_INS.encode_size(ins.value) + _CODER_JUMP.encode_size(arg1.size) +
arg1.size + arg2.size)
elif ins == _Instruction.DEFAULT:
assert isinstance(arg1, int)
assert isinstance(arg2, _BinNode)
self.size = _CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1) + arg2.size
elif ins == _Instruction.MATCH:
assert isinstance(arg1, int)
assert isinstance(arg2, _BinNode)
self.size = (_CODER_INS.encode_size(ins.value) + _CODER_MATCH.encode_size(arg1)
+ arg2.size)
elif ins == _Instruction.END:
assert arg1 is None
assert arg2 is None
self.size = 0
else:
assert False
@staticmethod
def make_end() -> "_BinNode":
"""Constructor for a _BinNode with just an END instruction."""
return _BinNode(_Instruction.END)
@staticmethod
def make_leaf(val: int) -> "_BinNode":
"""Constructor for a _BinNode of just a RETURN instruction."""
assert val is not None and val > 0
return _BinNode(_Instruction.RETURN, val)
@staticmethod
def make_branch(node0: "_BinNode", node1: "_BinNode") -> "_BinNode":
"""
Construct a _BinNode corresponding to running either the node0 or node1 subprogram,
based on the next input bit. It exploits shortcuts that are possible in the encoding,
and uses either a JUMP, MATCH, or END instruction.
"""
if node0.ins == _Instruction.END and node1.ins == _Instruction.END:
return node0
if node0.ins == _Instruction.END:
if node1.ins == _Instruction.MATCH and node1.arg1 <= 0xFF:
return _BinNode(node1.ins, node1.arg1 + (1 << node1.arg1.bit_length()), node1.arg2)
return _BinNode(_Instruction.MATCH, 3, node1)
if node1.ins == _Instruction.END:
if node0.ins == _Instruction.MATCH and node0.arg1 <= 0xFF:
return _BinNode(node0.ins, node0.arg1 + (1 << (node0.arg1.bit_length() - 1)),
node0.arg2)
return _BinNode(_Instruction.MATCH, 2, node0)
return _BinNode(_Instruction.JUMP, node0, node1)
@staticmethod
def make_default(val: int, sub: "_BinNode") -> "_BinNode":
"""
Construct a _BinNode that corresponds to the specified subprogram, with the specified
default value. It exploits shortcuts that are possible in the encoding, and will use
either a DEFAULT or a RETURN instruction."""
assert val is not None and val > 0
if sub.ins == _Instruction.END:
return _BinNode(_Instruction.RETURN, val)
if sub.ins in (_Instruction.RETURN, _Instruction.DEFAULT):
return sub
return _BinNode(_Instruction.DEFAULT, val, sub)
@total_ordering
class ASMap:
"""
A class whose objects represent a mapping from subnets to ASNs.
Internally the mapping is stored as a binary trie, but can be converted
from/to a list of ASNEntry objects, and from/to the binary asmap file format.
In the trie representation, nodes are represented as bare lists for efficiency
and ease of manipulation:
- [0] means an unassigned subnet (no ASN mapping for it is present)
- [int] means a subnet mapped entirely to the specified ASN.
- [node,node] means a subnet whose lower half and upper half have different
- mappings, represented by new trie nodes.
"""
def update(self, prefix: List[bool], asn: int) -> None:
"""Update this ASMap object to map prefix to the specified asn."""
assert asn == 0 or _CODER_ASN.can_encode(asn)
def recurse(node: List, offset: int) -> None:
if offset == len(prefix):
# Reached the end of prefix; overwrite this node.
node.clear()
node.append(asn)
return
if len(node) == 1:
# Need to descend into a leaf node; split it up.
oldasn = node[0]
node.clear()
node.append([oldasn])
node.append([oldasn])
# Descend into the node.
recurse(node[prefix[offset]], offset + 1)
# If the result is two identical leaf children, merge them.
if len(node[0]) == 1 and len(node[1]) == 1 and node[0] == node[1]:
oldasn = node[0][0]
node.clear()
node.append(oldasn)
recurse(self._trie, 0)
def update_multi(self, entries: List[Tuple[List[bool], int]]) -> None:
"""Apply multiple update operations, where longer prefixes take precedence."""
entries.sort(key=lambda entry: len(entry[0]))
for prefix, asn in entries:
self.update(prefix, asn)
def _set_trie(self, trie) -> None:
"""Set trie directly. Internal use only."""
def recurse(node: List) -> None:
if len(node) < 2:
return
recurse(node[0])
recurse(node[1])
if len(node[0]) == 2:
return
if node[0] == node[1]:
if len(node[0]) == 0:
node.clear()
else:
asn = node[0][0]
node.clear()
node.append(asn)
recurse(trie)
self._trie = trie
def __init__(self, entries: Optional[Iterable[ASNEntry]] = None) -> None:
"""Construct an ASMap object from an optional list of entries."""
self._trie = [0]
if entries is not None:
def entry_key(entry):
"""Sort function that places shorter prefixes first."""
prefix, asn = entry
return len(prefix), prefix, asn
for prefix, asn in sorted(entries, key=entry_key):
self.update(prefix, asn)
def lookup(self, prefix: List[bool]) -> Optional[int]:
"""Look up a prefix. Returns ASN, or 0 if unassigned, or None if indeterminate."""
node = self._trie
for bit in prefix:
if len(node) == 1:
break
node = node[bit]
if len(node) == 1:
return node[0]
return None
def _to_entries_flat(self, fill: bool = False) -> List[ASNEntry]:
"""Convert an ASMap object to a list of non-overlapping (prefix, asn) objects."""
prefix : List[bool] = []
def recurse(node: List) -> List[ASNEntry]:
ret = []
if len(node) == 1:
if node[0] > 0:
ret = [(list(prefix), node[0])]
elif len(node) == 2:
prefix.append(False)
ret = recurse(node[0])
prefix[-1] = True
ret += recurse(node[1])
prefix.pop()
if fill and len(ret) > 1:
asns = set(x[1] for x in ret)
if len(asns) == 1:
ret = [(list(prefix), list(asns)[0])]
return ret
return recurse(self._trie)
def _to_entries_minimal(self, fill: bool = False) -> List[ASNEntry]:
"""Convert a trie to a minimal list of ASNEntry objects, exploiting overlap."""
prefix : List[bool] = []
def recurse(node: List) -> (Tuple[Dict[Optional[int], List[ASNEntry]], bool]):
if len(node) == 1 and node[0] == 0:
return {None if fill else 0: []}, True
if len(node) == 1:
return {node[0]: [], None: [(list(prefix), node[0])]}, False
ret: Dict[Optional[int], List[ASNEntry]] = {}
prefix.append(False)
left, lhole = recurse(node[0])
prefix[-1] = True
right, rhole = recurse(node[1])
prefix.pop()
hole = not fill and (lhole or rhole)
def candidate(ctx: Optional[int], res0: Optional[List[ASNEntry]],
res1: Optional[List[ASNEntry]]):
if res0 is not None and res1 is not None:
if ctx not in ret or len(res0) + len(res1) < len(ret[ctx]):
ret[ctx] = res0 + res1
for ctx in set(left) | set(right):
candidate(ctx, left.get(ctx), right.get(ctx))
candidate(ctx, left.get(None), right.get(ctx))
candidate(ctx, left.get(ctx), right.get(None))
if not hole:
for ctx in list(ret):
if ctx is not None:
candidate(None, [(list(prefix), ctx)], ret[ctx])
if None in ret:
ret = {ctx:entries for ctx, entries in ret.items()
if ctx is None or len(entries) < len(ret[None])}
if hole:
ret = {ctx:entries for ctx, entries in ret.items() if ctx is None or ctx == 0}
return ret, hole
res, _ = recurse(self._trie)
return res[0] if 0 in res else res[None]
def __str__(self) -> str:
"""Convert this ASMap object to a string containing Python code constructing it."""
return f"ASMap({self._trie})"
def to_entries(self, overlapping: bool = True, fill: bool = False) -> List[ASNEntry]:
"""
Convert the mappings in this ASMap object to a list of ASNEntry objects.
Arguments:
overlapping: Permit the subnets in the resulting ASNEntry to overlap.
Setting this can result in a shorter list.
fill: Permit the resulting ASNEntry objects to cover subnets that
are unassigned in this ASMap object. Setting this can
result in a shorter list.
"""
if overlapping:
return self._to_entries_minimal(fill)
return self._to_entries_flat(fill)
@staticmethod
def from_random(num_leaves: int = 10, max_asn: int = 6,
unassigned_prob: float = 0.5) -> "ASMap":
"""
Construct a random ASMap object, with specified:
- Number of leaves in its trie (at least 1)
- Maximum ASN value (at least 1)
- Probability for leaf nodes to be unassigned
The number of leaves in the resulting object may be less than what is
requested. This method is mostly intended for testing.
"""
assert num_leaves >= 1
assert max_asn >= 1 or unassigned_prob == 1
assert _CODER_ASN.can_encode(max_asn)
assert 0.0 <= unassigned_prob <= 1.0
trie: List = []
leaves = [trie]
ret = ASMap()
for i in range(1, num_leaves):
idx = random.randrange(i)
leaf = leaves[idx]
lastleaf = leaves.pop()
if idx + 1 < i:
leaves[idx] = lastleaf
leaf.append([])
leaf.append([])
leaves.append(leaf[0])
leaves.append(leaf[1])
for leaf in leaves:
if random.random() >= unassigned_prob:
leaf.append(random.randrange(1, max_asn + 1))
else:
leaf.append(0)
#pylint: disable=protected-access
ret._set_trie(trie)
return ret
def _to_binnode(self, fill: bool = False) -> _BinNode:
"""Convert a trie to a _BinNode object."""
def recurse(node: List) -> Tuple[Dict[Optional[int], _BinNode], bool]:
if len(node) == 1 and node[0] == 0:
return {(None if fill else 0): _BinNode.make_end()}, True
if len(node) == 1:
return {None: _BinNode.make_leaf(node[0]), node[0]: _BinNode.make_end()}, False
ret: Dict[Optional[int], _BinNode] = {}
left, lhole = recurse(node[0])
right, rhole = recurse(node[1])
hole = (lhole or rhole) and not fill
def candidate(ctx: Optional[int], arg1, arg2, func: Callable):
if arg1 is not None and arg2 is not None:
cand = func(arg1, arg2)
if ctx not in ret or cand.size < ret[ctx].size:
ret[ctx] = cand
for ctx in set(left) | set(right):
candidate(ctx, left.get(ctx), right.get(ctx), _BinNode.make_branch)
candidate(ctx, left.get(None), right.get(ctx), _BinNode.make_branch)
candidate(ctx, left.get(ctx), right.get(None), _BinNode.make_branch)
if not hole:
for ctx in set(ret) - set([None]):
candidate(None, ctx, ret[ctx], _BinNode.make_default)
if None in ret:
ret = {ctx:enc for ctx, enc in ret.items()
if ctx is None or enc.size < ret[None].size}
if hole:
ret = {ctx:enc for ctx, enc in ret.items() if ctx is None or ctx == 0}
return ret, hole
res, _ = recurse(self._trie)
return res[0] if 0 in res else res[None]
@staticmethod
def _from_binnode(binnode: _BinNode) -> "ASMap":
"""Construct an ASMap object from a _BinNode. Internal use only."""
def recurse(node: _BinNode, default: int) -> List:
if node.ins == _Instruction.RETURN:
return [node.arg1]
if node.ins == _Instruction.JUMP:
return [recurse(node.arg1, default), recurse(node.arg2, default)]
if node.ins == _Instruction.MATCH:
val = node.arg1
sub = recurse(node.arg2, default)
while val >= 2:
bit = val & 1
val >>= 1
if bit:
sub = [[default], sub]
else:
sub = [sub, [default]]
return sub
assert node.ins == _Instruction.DEFAULT
return recurse(node.arg2, node.arg1)
ret = ASMap()
if binnode.ins != _Instruction.END:
#pylint: disable=protected-access
ret._set_trie(recurse(binnode, 0))
return ret
def to_binary(self, fill: bool = False) -> bytes:
"""
Convert this ASMap object to binary.
Argument:
fill: permit the resulting binary encoder to contain mappers for
unassigned subnets in this ASMap object. Doing so may
reduce the size of the encoding.
Returns:
A bytes object with the encoding of this ASMap object.
"""
bits: List[int] = []
def recurse(node: _BinNode) -> None:
_CODER_INS.encode(node.ins.value, bits)
if node.ins == _Instruction.RETURN:
_CODER_ASN.encode(node.arg1, bits)
elif node.ins == _Instruction.JUMP:
_CODER_JUMP.encode(node.arg1.size, bits)
recurse(node.arg1)
recurse(node.arg2)
elif node.ins == _Instruction.DEFAULT:
_CODER_ASN.encode(node.arg1, bits)
recurse(node.arg2)
else:
assert node.ins == _Instruction.MATCH
_CODER_MATCH.encode(node.arg1, bits)
recurse(node.arg2)
binnode = self._to_binnode(fill)
if binnode.ins != _Instruction.END:
recurse(binnode)
val = 0
nbits = 0
ret = []
for bit in bits:
val += (bit << nbits)
nbits += 1
if nbits == 8:
ret.append(val)
val = 0
nbits = 0
if nbits:
ret.append(val)
return bytes(ret)
@staticmethod
def from_binary(bindata: bytes) -> Optional["ASMap"]:
"""Decode an ASMap object from the provided binary encoding."""
bits: List[int] = []
for byte in bindata:
bits.extend((byte >> i) & 1 for i in range(8))
def recurse(bitpos: int) -> Tuple[_BinNode, int]:
insval, bitpos = _CODER_INS.decode(bits, bitpos)
ins = _Instruction(insval)
if ins == _Instruction.RETURN:
asn, bitpos = _CODER_ASN.decode(bits, bitpos)
return _BinNode(ins, asn), bitpos
if ins == _Instruction.JUMP:
jump, bitpos = _CODER_JUMP.decode(bits, bitpos)
left, bitpos1 = recurse(bitpos)
if bitpos1 != bitpos + jump:
raise ValueError("Inconsistent jump")
right, bitpos = recurse(bitpos1)
return _BinNode(ins, left, right), bitpos
if ins == _Instruction.MATCH:
match, bitpos = _CODER_MATCH.decode(bits, bitpos)
sub, bitpos = recurse(bitpos)
return _BinNode(ins, match, sub), bitpos
assert ins == _Instruction.DEFAULT
asn, bitpos = _CODER_ASN.decode(bits, bitpos)
sub, bitpos = recurse(bitpos)
return _BinNode(ins, asn, sub), bitpos
if len(bits) == 0:
binnode = _BinNode(_Instruction.END)
else:
try:
binnode, bitpos = recurse(0)
except (ValueError, IndexError):
return None
if bitpos < len(bits) - 7:
return None
if not all(bit == 0 for bit in bits[bitpos:]):
return None
return ASMap._from_binnode(binnode)
def __lt__(self, other: "ASMap") -> bool:
return self._trie < other._trie
def __eq__(self, other: object) -> bool:
if isinstance(other, ASMap):
return self._trie == other._trie
return False
def extends(self, req: "ASMap") -> bool:
"""Determine whether this matches req for all subranges where req is assigned."""
def recurse(actual: List, require: List) -> bool:
if len(require) == 1 and require[0] == 0:
return True
if len(require) == 1:
if len(actual) == 1:
return bool(require[0] == actual[0])
return recurse(actual[0], require) and recurse(actual[1], require)
if len(actual) == 2:
return recurse(actual[0], require[0]) and recurse(actual[1], require[1])
return recurse(actual, require[0]) and recurse(actual, require[1])
assert isinstance(req, ASMap)
#pylint: disable=protected-access
return recurse(self._trie, req._trie)
def diff(self, other: "ASMap") -> List[ASNDiff]:
"""Compute the diff from self to other."""
prefix: List[bool] = []
ret: List[ASNDiff] = []
def recurse(old_node: List, new_node: List):
if len(old_node) == 1 and len(new_node) == 1:
if old_node[0] != new_node[0]:
ret.append((list(prefix), old_node[0], new_node[0]))
else:
old_left: List = old_node if len(old_node) == 1 else old_node[0]
old_right: List = old_node if len(old_node) == 1 else old_node[1]
new_left: List = new_node if len(new_node) == 1 else new_node[0]
new_right: List = new_node if len(new_node) == 1 else new_node[1]
prefix.append(False)
recurse(old_left, new_left)
prefix[-1] = True
recurse(old_right, new_right)
prefix.pop()
assert isinstance(other, ASMap)
#pylint: disable=protected-access
recurse(self._trie, other._trie)
return ret
def __copy__(self) -> "ASMap":
"""Construct a copy of this ASMap object. Its state will not be shared."""
ret = ASMap()
#pylint: disable=protected-access
ret._set_trie(copy.deepcopy(self._trie))
return ret
def __deepcopy__(self, _) -> "ASMap":
# ASMap objects do not allow sharing of the _trie member, so we don't need the memoization.
return self.__copy__()
class TestASMap(unittest.TestCase):
"""Unit tests for this module."""
def test_ipv6_prefix_roundtrips(self) -> None:
"""Test that random IPv6 network ranges roundtrip through prefix encoding."""
for _ in range(20):
net_bits = random.getrandbits(128)
for prefix_len in range(0, 129):
masked_bits = (net_bits >> (128 - prefix_len)) << (128 - prefix_len)
net = ipaddress.IPv6Network((masked_bits.to_bytes(16, 'big'), prefix_len))
prefix = net_to_prefix(net)
self.assertTrue(len(prefix) <= 128)
net2 = prefix_to_net(prefix)
self.assertEqual(net, net2)
def test_ipv4_prefix_roundtrips(self) -> None:
"""Test that random IPv4 network ranges roundtrip through prefix encoding."""
for _ in range(100):
net_bits = random.getrandbits(32)
for prefix_len in range(0, 33):
masked_bits = (net_bits >> (32 - prefix_len)) << (32 - prefix_len)
net = ipaddress.IPv4Network((masked_bits.to_bytes(4, 'big'), prefix_len))
prefix = net_to_prefix(net)
self.assertTrue(32 <= len(prefix) <= 128)
net2 = prefix_to_net(prefix)
self.assertEqual(net, net2)
def test_asmap_roundtrips(self) -> None:
"""Test case that verifies random ASMap objects roundtrip to/from entries/binary."""
# Iterate over the number of leaves the random test ASMap objects have.
for leaves in range(1, 20):
# Iterate over the number of bits in the AS numbers used.
for asnbits in range(0, 24):
# Iterate over the probability that leaves are unassigned.
for pct in range(101):
# Construct a random ASMap object according to the above parameters.
asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
unassigned_prob=0.01 * pct)
# Run tests for to_entries and construction from those entries, both
# for overlapping and non-overlapping ones.
for overlapping in [False, True]:
entries = asmap.to_entries(overlapping=overlapping, fill=False)
random.shuffle(entries)
asmap2 = ASMap(entries)
assert asmap2 is not None
self.assertEqual(asmap2, asmap)
entries = asmap.to_entries(overlapping=overlapping, fill=True)
random.shuffle(entries)
asmap2 = ASMap(entries)
assert asmap2 is not None
self.assertTrue(asmap2.extends(asmap))
# Run tests for to_binary and construction from binary.
enc = asmap.to_binary(fill=False)
asmap3 = ASMap.from_binary(enc)
assert asmap3 is not None
self.assertEqual(asmap3, asmap)
enc = asmap.to_binary(fill=True)
asmap3 = ASMap.from_binary(enc)
assert asmap3 is not None
self.assertTrue(asmap3.extends(asmap))
def test_patching(self) -> None:
"""Test behavior of update, lookup, extends, and diff."""
#pylint: disable=too-many-locals,too-many-nested-blocks
# Iterate over the number of leaves the random test ASMap objects have.
for leaves in range(1, 20):
# Iterate over the number of bits in the AS numbers used.
for asnbits in range(0, 10):
# Iterate over the probability that leaves are unassigned.
for pct in range(0, 101):
# Construct a random ASMap object according to the above parameters.
asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
unassigned_prob=0.01 * pct)
# Make a copy of that asmap object to which patches will be applied.
# It starts off being equal to asmap.
patched = copy.copy(asmap)
# Keep a list of patches performed.
patches: List[ASNEntry] = []
# Initially there cannot be any difference.
self.assertEqual(asmap.diff(patched), [])
# Make 5 patches, each building on top of the previous ones.
for _ in range(0, 5):
# Construct a random path and new ASN to assign it to, apply it to patched,
# and remember it in patches.
pathlen = random.randrange(5)
path = [random.getrandbits(1) != 0 for _ in range(pathlen)]
newasn = random.randrange(1 + (1 << asnbits))
patched.update(path, newasn)
patches = [(path, newasn)] + patches
# Compute the diff, and whether asmap extends patched, and the other way
# around.
diff = asmap.diff(patched)
self.assertEqual(asmap == patched, len(diff) == 0)
extends = asmap.extends(patched)
back_extends = patched.extends(asmap)
# Determine whether those extends results are consistent with the diff
# result.
self.assertEqual(extends, all(d[2] == 0 for d in diff))
self.assertEqual(back_extends, all(d[1] == 0 for d in diff))
# For every diff found:
for path, old_asn, new_asn in diff:
# Verify asmap and patched actually differ there.
self.assertTrue(old_asn != new_asn)
self.assertEqual(asmap.lookup(path), old_asn)
self.assertEqual(patched.lookup(path), new_asn)
for _ in range(2):
# Extend the path far enough that it's smaller than any mapped
# range, and check the lookup holds there too.
spec_path = list(path)
while len(spec_path) < 32:
spec_path.append(random.getrandbits(1) != 0)
self.assertEqual(asmap.lookup(spec_path), old_asn)
self.assertEqual(patched.lookup(spec_path), new_asn)
# Search through the list of performed patches to find the last one
# applying to the extended path (note that patches is in reverse
# order, so the first match should work).
found = False
for patch_path, patch_asn in patches:
if spec_path[:len(patch_path)] == patch_path:
# When found, it must match whatever the result was patched
# to.
self.assertEqual(new_asn, patch_asn)
found = True
break
# And such a patch must exist.
self.assertTrue(found)
if __name__ == '__main__':
unittest.main()