0
0
mirror of https://gitlab.nic.cz/labs/bird.git synced 2024-12-22 17:51:53 +00:00

Flock tests: DumpCheck class for less code duplication in checks

This commit is contained in:
Maria Matejka 2024-07-25 15:06:45 +02:00
parent 2e4da8e329
commit 3887a0bfee
3 changed files with 135 additions and 81 deletions

18
python/BIRD/Aux.py Normal file
View File

@ -0,0 +1,18 @@
import asyncio
async def dict_gather(d: dict):
return dict(zip(d.keys(), await asyncio.gather(*d.values())))
def dict_expand(d: dict):
out = {}
for k,v in d.items():
p,*r = k
if p not in out:
out[p] = {}
try:
r ,= r
except ValueError:
r = tuple(r)
out[p][r] = v
return out

View File

@ -222,6 +222,122 @@ class BIRDInstance(CLI):
self.bindir.cleanup(self.workdir) self.bindir.cleanup(self.workdir)
class DumpCheck:
def __init__(self, test, timeout, name, check_timeout=None, check_id=None, check_retry_timeout=0.5):
self.timeout = timeout
self.check_timeout = timeout if check_timeout is None else check_timeout
self.check_retry_timeout = check_retry_timeout
self.name = name
self.show_difs = test.show_difs
# Compile dump ID
if check_id is None:
try:
test.check_id += 1
except AttributeError:
test.check_id = 1
self.id = test.check_id
else:
self.id = check_id
if name is None:
self.stem = f"{self.id:04d}"
else:
self.stem = f"{self.id:04d}-{name}"
self.file = f"dump-{self.stem}.yaml"
match test.mode:
case Test.SAVE:
self.run = self.save
case Test.CHECK:
self.run = self.check
case _:
raise Exception("Invalid test mode")
def __call__(self):
print(f"{self.stem}\t", end="", flush=True)
return self.run()
async def save(self):
await asyncio.sleep(self.timeout)
dump = await self.obtain()
with open(self.file, "w") as y:
yaml.dump(dump, y)
print(f"[ SAVED ]")
async def check(self):
with open(self.file, "r") as y:
c = yaml.safe_load(y)
seen = []
try:
async with asyncio.timeout(self.check_timeout) as to:
while True:
dump = await self.obtain()
try:
deep_eq(c, dump, True)
# if deep_eq(c, dump):
spent = asyncio.get_running_loop().time() - to.when() + self.check_timeout
print(f"[ OOK ]\t{spent:.6f}s")
return True
except Differs as d:
if self.show_difs:
print(f"Differs at {' -> '.join([str(s) for s in reversed(d.tree)])}: {d.a} != {d.b}")
seen.append(dump)
await asyncio.sleep(self.check_retry_timeout)
except TimeoutError as e:
print(f"[ BAD ]")
for q in range(len(seen)):
with open(f"__result_bad_{q}__{self.stem}", "w") as y:
yaml.dump(seen[q], y)
return False
class DumpOnMachines(DumpCheck):
def __init__(self, test, *args, machines=None, **kwargs):
super().__init__(test, *args, **kwargs)
# Collect machines to dump
if machines is None:
self.machines = test.machine_index.values()
else:
self.machines = [
m if isinstance(m, CLI) else test.machine_index[m]
for m in machines
]
async def obtain(self):
return await dict_gather({
m.mach.name: self.obtain_on_machine(m)
for m in self.machines
})
class DumpRIB(DumpOnMachines):
def __init__(self, *args, full=True, **kwargs):
super().__init__(*args, **kwargs)
self.args = []
if full:
self.args.append("all")
async def obtain_on_machine(self, mach):
d = await mach.show_route(args=self.args)
assert("version" in d)
del d["version"]
for t in d["tables"].values():
for n in t["networks"].values():
for r in n["routes"]:
for k in ("when", "!_l", "!_g", "!_s", "!_id"):
assert(k in r)
del r[k]
return d
class Test: class Test:
ipv6_prefix = ipaddress.ip_network("2001:db8::/32") ipv6_prefix = ipaddress.ip_network("2001:db8::/32")
@ -353,86 +469,6 @@ class Test:
print("cleaning up") print("cleaning up")
await self.cleanup() await self.cleanup()
async def route_dump(self, timeout, name, full=True, machines=None, check_timeout=10, check_retry_timeout=0.5):
# Compile dump ID
self.route_dump_id += 1
if name is None:
name = f"dump-{self.route_dump_id:04d}.yaml"
else:
name = f"dump-{self.route_dump_id:04d}-{name}.yaml"
print(f"{name}\t{self.route_dump_id}\t", end="", flush=True)
# Collect machines to dump
if machines is None:
machines = self.machine_index.values()
else:
machines = [
m if isinstance(m, CLI) else self.machine_index[m]
for m in machines
]
# Compile command
args = []
if full:
args.append("all")
# Define the obtainer function
async def obtain():
dump = await asyncio.gather(*[
where.show_route(args=args)
for where in machines
])
for d in dump:
for t in d["tables"].values():
for n in t["networks"].values():
for r in n["routes"]:
for k in ("when", "!_l", "!_g", "!_s", "!_id"):
assert(k in r)
del r[k]
assert("version" in d)
del d["version"]
return dump
match self.mode:
case Test.SAVE:
await asyncio.sleep(timeout)
dump = await obtain()
with open(name, "w") as y:
yaml.dump_all(dump, y)
print(f"[ SAVED ]")
case Test.CHECK:
with open(name, "r") as y:
c = [*yaml.safe_load_all(y)]
seen = []
try:
async with asyncio.timeout(check_timeout) as to:
while True:
dump = await obtain()
try:
deep_eq(c, dump, True)
# if deep_eq(c, dump):
spent = asyncio.get_running_loop().time() - to.when() + check_timeout
print(f"[ OOK ]\t{spent:.6f}s")
return True
except Differs as d:
if self.show_difs:
print(f"Differs at {' -> '.join([str(s) for s in reversed(d.tree)])}: {d.a} != {d.b}")
seen.append(dump)
await asyncio.sleep(check_retry_timeout)
except TimeoutError as e:
print(f"[ BAD ]")
for q in range(len(seen)):
with open(f"__result_bad_{q}__{name}", "w") as y:
yaml.dump_all(seen[q], y)
return False
case _:
raise Exception("Invalid test mode")
if __name__ == "__main__": if __name__ == "__main__":
name = sys.argv[1] name = sys.argv[1]

@ -1 +1 @@
Subproject commit 8a733f18b122470e092915013e2b2cec1cb2baec Subproject commit d30489bec7ed8fa4d2009f785fb6f53cff6625c5