#!/usr/bin/python3

import inspect
import os
import pathlib
import socket
import subprocess
import sys

import cbor2

# Directory holding the per-hypervisor control sockets (<name>.ctl), placed
# under the current user's /run/user/<uid> directory.
DEFAULT_RUN_PATH = pathlib.Path("/run/user", str(os.getuid()), "flock")
# Command tree populated by @handler: nested dicts for sub-command groups,
# with handler functions as leaves.
handlers: dict = dict()

class HandlerError(Exception):
    """Base class for errors a command handler reports to the user.

    The dispatcher prints these as ``Error: <message>`` and exits with
    status 1 instead of showing a traceback.
    """

def handler(fun):
    """Register *fun* as a CLI command handler and return it unchanged.

    The function name is split on ``_`` to form a command path. ``start``
    registers the top-level command ``start``, while ``container_start``
    registers ``start`` inside the ``container`` sub-command group.
    Intermediate groups are created as plain dicts in the module-level
    ``handlers`` tree, and the final path component maps to *fun* itself.

    Returns:
        *fun*, so the decorated module-level name stays bound to the function.

    Raises:
        Exception: if the command path is already registered, or if a path
            component is an existing leaf handler rather than a group.
    """
    *groups, leaf = fun.__name__.split("_")

    node = handlers
    for group in groups:
        node = node.setdefault(group, {})
        # A leaf handler cannot also act as a sub-command group.
        if not isinstance(node, dict):
            raise Exception(
                f"Handler {fun.__name__} conflicts with existing command {group}"
            )

    if leaf in node:
        raise Exception(f"Duplicate handler {fun.__name__}")

    node[leaf] = fun
    return fun

class HypervisorNonexistentError(HandlerError):
    """Raised by ``connect`` when the hypervisor control socket path does not exist."""

    def __init__(self, *args, **kwargs):
        # The fixed message comes first; any extra arguments follow it in ``args``.
        super().__init__("Hypervisor not found", *args, **kwargs)

class HypervisorStaleError(HandlerError):
    """Raised by ``connect`` when the control socket exists but refuses connections."""

    def __init__(self, *args, **kwargs):
        # The fixed message comes first; any extra arguments follow it in ``args``.
        super().__init__("Hypervisor stale", *args, **kwargs)

def connect(where: pathlib.Path) -> socket.socket:
    """Open a stream connection to the Unix-domain control socket at *where*.

    Returns:
        The connected socket. The caller owns it and must close it.

    Raises:
        HypervisorNonexistentError: if *where* does not exist.
        HypervisorStaleError: if *where* exists but the connection is refused.
    """
    if not where.exists():
        raise HypervisorNonexistentError()

    client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        client.connect(bytes(where))
    except ConnectionRefusedError as e:
        client.close()
        raise HypervisorStaleError() from e
    except FileNotFoundError as e:
        # The socket file vanished between the exists() check and connect().
        client.close()
        raise HypervisorNonexistentError() from e
    except BaseException:
        # Never leak the descriptor, whatever else goes wrong.
        client.close()
        raise
    return client

def ctl_path(name: str) -> pathlib.Path:
    """Return the control-socket path ``<DEFAULT_RUN_PATH>/<name>.ctl`` for hypervisor *name*."""
    filename = f"{name}.ctl"
    return DEFAULT_RUN_PATH.joinpath(filename)

def msg(name: str, data: dict):
    """Send one CBOR-encoded request to hypervisor *name* and return the decoded reply.

    Args:
        name: hypervisor name; its control socket is ``ctl_path(name)``.
        data: request object, encoded with ``cbor2.dumps``.

    Returns:
        The object decoded from the reply with ``cbor2.loads``. The handlers
        in this file iterate it with ``.items()``, i.e. they treat it as a mapping.

    Raises:
        HypervisorNonexistentError, HypervisorStaleError: propagated from
            ``connect``, each with a note naming the message and the hypervisor.
    """
    try:
        ctl = connect(ctl_path(name))
    except (HypervisorNonexistentError, HypervisorStaleError) as e:
        e.add_note(f"Failed to send message {data} to {name}")
        raise

    # Close the connection once the single request/reply exchange is done.
    with ctl:
        ctl.sendall(cbor2.dumps(data))
        # NOTE(review): assumes the complete reply arrives in one recv() of at
        # most 1024 bytes; a larger or fragmented reply would fail to decode.
        return cbor2.loads(ctl.recv(1024))

@handler
def start(name: str):
    """Launch hypervisor *name* with ``./flock-sim`` unless one is already reachable.

    Creates ``DEFAULT_RUN_PATH`` if needed, then runs
    ``./flock-sim -s <ctl path> <name>`` (resolved relative to the current
    working directory) and waits for it to exit. Its exit status is not checked.

    Raises:
        HandlerError: if the control socket already accepts connections.
        HypervisorStaleError: if a stale control socket is left behind
            (remove it with the ``cleanup`` command).
    """
    DEFAULT_RUN_PATH.mkdir(parents=True, exist_ok=True)
    ctl = ctl_path(name)
    try:
        probe = connect(ctl)
    except HypervisorNonexistentError:
        pass
    else:
        # Something is listening; release the probe connection before refusing.
        probe.close()
        raise HandlerError("Hypervisor already exists")

    subprocess.run(["./flock-sim", "-s", ctl, name])

@handler
def stop(name: str):
    """Ask hypervisor *name* to stop.

    Sends the request ``{0: None}`` and requires every entry of the reply
    to be ``-1: "OK"``.

    Raises:
        HandlerError: if the reply contains any other entry.
    """
    for k, v in msg(name, {0: None}).items():
        # Validate explicitly: ``assert`` is stripped under ``python -O``.
        if k != -1 or v != "OK":
            raise HandlerError(f"Unexpected reply to stop: {k!r}: {v!r}")

@handler
def cleanup(name: str):
    """Remove the control socket of hypervisor *name* if it is stale.

    Raises:
        HandlerError: if the control socket still accepts connections.
        HypervisorNonexistentError: if there is no control socket file.
    """
    ctl = ctl_path(name)
    try:
        probe = connect(ctl)
    except HypervisorStaleError:
        # Tolerate the file disappearing between the probe and the unlink.
        ctl.unlink(missing_ok=True)
    else:
        # Something is still listening; release the probe connection and refuse.
        probe.close()
        raise HandlerError("Hypervisor is not stale")

@handler
def telnet(name: str):
    """Open a telnet session to hypervisor *name*.

    Sends the request ``{1: None}``. The value of the reply's ``-2`` entry is
    passed as the port to ``telnet localhost <port>`` via ``os.execlp``, which
    replaces the current process.

    Raises:
        HandlerError: if the reply has an unexpected key or no entries.
    """
    for k, v in msg(name, {1: None}).items():
        # Validate explicitly: ``assert`` is stripped under ``python -O``.
        if k != -2:
            raise HandlerError(f"Unexpected reply to telnet: {k!r}: {v!r}")
        # execlp only returns by raising, so at most one entry is ever used.
        os.execlp("telnet", "telnet", "localhost", str(v))
    raise HandlerError("Hypervisor sent an empty reply to telnet")

@handler
def container_start(hypervisor: str, name: str):
    """Ask *hypervisor* to start container *name* and print each reply entry.

    The request is keyed by ``3``. Its payload carries the container name
    (key 0), the literal ``1`` (key 1), ``b"/"`` (key 2) and the byte-encoded
    path ``DEFAULT_RUN_PATH/<hypervisor>/<name>`` (key 3).
    NOTE(review): the meaning of keys 1 and 2 is not visible here; confirm
    against the flock-sim protocol.
    """
    instance_path = bytes(DEFAULT_RUN_PATH / hypervisor / name)
    request = {
        3: {
            0: name,
            1: 1,
            2: b"/",
            3: instance_path,
        },
    }
    reply = msg(hypervisor, request)
    for key, value in reply.items():
        print(key, value)

@handler
def container_stop(hypervisor: str, name: str):
    """Ask *hypervisor* to stop container *name* (request key 4) and print each reply entry."""
    reply = msg(hypervisor, {4: {0: name}})
    for entry in reply.items():
        print(*entry)

@handler
def container_telnet(hypervisor: str, name: str):
    """Open a telnet session to container *name* on *hypervisor*.

    Sends the request ``{1: name}``. The value of the reply's ``-2`` entry is
    passed as the port to ``telnet localhost <port>`` via ``os.execlp``, which
    replaces the current process.

    Raises:
        HandlerError: if the reply has an unexpected key or no entries.
    """
    for k, v in msg(hypervisor, {1: name}).items():
        # Validate explicitly: ``assert`` is stripped under ``python -O``.
        if k != -2:
            raise HandlerError(
                f"Unexpected reply to container telnet: {k!r}: {v!r}"
            )
        # execlp only returns by raising, so at most one entry is ever used.
        os.execlp("telnet", "telnet", "localhost", str(v))
    raise HandlerError("Hypervisor sent an empty reply to container telnet")

# Consume the program name so that the remaining argv words form the command.
try:
    binname = sys.argv.pop(0)
except (AttributeError, IndexError) as e:
    raise RuntimeError("cannot determine program name from sys.argv") from e
    
def usage(name: str):
    """Print the command-line help for program *name* to stdout, one entry per line."""
    lines = [
        f"Usage: {name} <command> <args>",
        "",
        "Available commands:",
        "\tstart <name>                             start Flock hypervisor",
        f"\t                                         creates <name>.ctl in {DEFAULT_RUN_PATH}",
        "\tstop <name>                              stop Flock hypervisor",
        "\tcleanup <name>                           cleanup the control socket left behind a stale hypervisor",
        "\ttelnet <hypervisor>                      run telnet to hypervisor",
        "\tcontainer start <hypervisor> <name>      start virtual machine",
        "\tcontainer stop <hypervisor> <name>       stop virtual machine",
        "\tcontainer telnet <hypervisor> <name>     run telnet to this machine",
    ]
    print("\n".join(lines))

# Walk the handler tree one argv word at a time until a handler (leaf) is reached.
cmd = []
hx = handlers
while isinstance(hx, dict):
    try:
        cx = sys.argv.pop(0)
        hx = hx[cx]
    except (IndexError, KeyError):
        usage(binname)
        sys.exit(2)

    cmd.append(cx)

# Check the argument count up front, so that a TypeError raised *inside* a
# handler is not misreported as a usage error.
try:
    inspect.signature(hx).bind(*sys.argv)
except TypeError:
    usage(binname)
    print()
    print(f"Error in command {' '.join(cmd)}.")
    sys.exit(2)

try:
    hx(*sys.argv)
except HandlerError as e:
    print(f"Error: {e}")
    sys.exit(1)