0
0
mirror of https://gitlab.nic.cz/labs/bird.git synced 2025-01-18 06:51:54 +00:00

maria's test aggregator works on IPv4 as well

This commit is contained in:
Maria Matejka 2023-12-25 23:23:19 +01:00
parent 58fac4921f
commit 1a8c065a36

View File

@ -2,7 +2,11 @@
import ipaddress
class IPTrie:
rootnet = None
agrclass = None
def __init__(self, up=None):
self.children = [ None, None ]
self.local = None
@ -31,10 +35,13 @@ class IPTrie:
(self.children[0].dump([ *path, 0 ]) if self.children[0] is not None else "") + \
(self.children[1].dump([ *path, 1 ]) if self.children[1] is not None else "")
def aggregate(self, up=None, net=ipaddress.IPv6Network("::/0"), covered=None):
def aggregate(self, up=None, net=None, covered=None):
if self.children[0] is None and self.children[1] is None:
return self
if net is None:
net = self.rootnet
if self.local:
covered = self.local
else:
@ -42,7 +49,7 @@ class IPTrie:
def coveredNode(bit):
t = IPTrie(self)
t.local = AgrPointv6(list(net.subnets())[bit], covered.bucket)
t.local = self.agrclass(list(net.subnets())[bit], covered.bucket)
t.buckets.add(covered.bucket)
return t
@ -59,7 +66,7 @@ class IPTrie:
intersection = ac[0].buckets & ac[1].buckets
if len(intersection) > 0:
nap.local = AgrPointv6(net, sorted(intersection)[0])
nap.local = self.agrclass(net, sorted(intersection)[0])
nap.buckets = intersection
else:
nap.buckets = ac[0].buckets | ac[1].buckets
@ -107,6 +114,21 @@ class AgrPointv6(ipaddress.IPv6Network):
def __init__(self, net, bucket):
super().__init__(net)
self.bucket = bucket
if IPTrie.rootnet is None:
IPTrie.rootnet = ipaddress.IPv6Network("::/0")
IPTrie.agrclass = AgrPointv6
def __str__(self):
# print(type(self), super().__str__(), type(self.bucket), self.bucket)
return super().__str__() + " -> " + self.bucket
class AgrPointv4(ipaddress.IPv4Network):
def __init__(self, net, bucket):
super().__init__(net)
self.bucket = bucket
if IPTrie.rootnet is None:
IPTrie.rootnet = ipaddress.IPv4Network("0.0.0.0/0")
IPTrie.agrclass = AgrPointv4
def __str__(self):
# print(type(self), super().__str__(), type(self.bucket), self.bucket)
@ -115,12 +137,25 @@ class AgrPointv6(ipaddress.IPv6Network):
# Load
t = IPTrie()
p = input()
data = p.split(" ")
try:
t.add(AgrPointv6(data[0], data[1]))
try:
while p := input():
data = p.split(" ")
t.add(AgrPointv6(data[0], data[1]))
except EOFError:
pass
except ipaddress.AddressValueError:
t.add(AgrPointv4(data[0], data[1]))
try:
while p := input():
data = p.split(" ")
t.add(AgrPointv4(data[0], data[1]))
except EOFError:
pass
# Dump
print("Dump After Load")