#!/usr/bin/python3

import codecs
import fcntl
import json
import logging
import os
# NOTE: `pipes` is deprecated since Python 3.11 and removed in 3.13;
# `shlex.quote` is the supported way to quote shell arguments.
import pipes
import shlex
import shutil
import socket
import sys
import time
from contextlib import closing
from queue import Queue, Empty
from threading import Thread, Lock, Condition, RLock
from typing import *

import requests

# Judger configuration; assigned by init() from '.conf.json' and read by the
# helpers and classes below.
jconf: dict


# OS-related functions
def clean_up_folder(path):
    """Delete everything inside *path*, leaving the directory itself in place.

    Real sub-directories are removed recursively; every other entry (regular
    files, symbolic links, dangling links, ...) is unlinked.  Symbolic links
    are never followed, so a link that points at a directory is removed as a
    link and its target is left untouched.

    Raises:
        OSError: if *path* cannot be listed or an entry cannot be removed.
    """
    # Snapshot the listing first so that deleting entries cannot disturb
    # the directory iteration.
    with os.scandir(path) as it:
        entries = list(it)

    for entry in entries:
        # shutil.rmtree() refuses to operate on a symlink, so only genuine
        # directories go through it.
        if entry.is_dir(follow_symlinks=False):
            shutil.rmtree(entry.path)
        else:
            os.unlink(entry.path)


def execute(cmd):
    """Run *cmd* through ``os.system``.

    Raises:
        Exception: if ``os.system`` returns a non-zero status.
    """
    status = os.system(cmd)
    if status == 0:
        return
    raise Exception('failed to execute: %s' % cmd)


def freopen(f, g):
    """Redirect *f* to the file behind *g*, then close *g*.

    The descriptor of *g* is duplicated onto the descriptor of *f* with
    ``os.dup2``, so later I/O through *f* goes to the file *g* had open.
    """
    source_fd = g.fileno()
    target_fd = f.fileno()
    os.dup2(source_fd, target_fd)
    g.close()


# UTF-8 decoding error handler
def uoj_text_replace(e):
    r"""Codec error handler that renders each offending byte as ``<b>\xNN</b>``.

    *e* is the codec error being handled; the bytes ``e.object[e.start:e.end]``
    are each replaced by their two-digit lowercase hex value wrapped in
    ``<b>`` tags.

    Returns:
        A ``(replacement_text, resume_position)`` tuple, resuming at ``e.end``.
    """
    offending = e.object[e.start:e.end]
    pieces = ['<b>\\x%02x</b>' % byte for byte in offending]
    return ''.join(pieces), e.end


# Register the handler above under the name 'uoj_text_replace' so it can be
# selected with errors='uoj_text_replace' when opening text files.
codecs.register_error('uoj_text_replace', uoj_text_replace)


# init
def init():
    """Load the judger configuration into the module-level ``jconf``.

    Changes the working directory to the directory that contains this script,
    parses ``.conf.json`` found there and checks that every mandatory setting
    is present.  ``judger_cpus`` defaults to ``['0']`` when not configured.
    ``jconf`` is only assigned once the configuration has been validated.

    Raises:
        ValueError: if the file does not hold a JSON object or a mandatory
            key is missing (explicit check; an ``assert`` would vanish
            under ``python -O``).
        OSError: if the configuration file cannot be read.
    """
    global jconf
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    with open('.conf.json', 'r') as fp:
        conf = json.load(fp)

    if not isinstance(conf, dict):
        raise ValueError('.conf.json must contain a JSON object')

    required_keys = (
        'uoj_protocol',
        'uoj_host',
        'judger_name',
        'judger_password',
        'socket_port',
        'socket_password',
    )
    missing = [key for key in required_keys if key not in conf]
    if missing:
        raise ValueError('missing keys in .conf.json: %s' % ', '.join(missing))

    # The CPUs should be distinct physical cores; usually it is enough to
    # list some even-numbered CPUs.
    conf.setdefault('judger_cpus', ['0'])
    jconf = conf


# path related function
def uoj_url(uri):
    """Return the absolute URL of *uri* on the configured UOJ web server.

    The URL is assembled from ``jconf['uoj_protocol']`` and
    ``jconf['uoj_host']``; trailing slashes are stripped from the result.
    """
    protocol = jconf['uoj_protocol']
    host = jconf['uoj_host']
    url = f'{protocol}://{host}{uri}'
    return url.rstrip('/')


def uoj_judger_path(path=''):
    """Return *path* appended to ``<current working directory>/uoj_judger``."""
    return ''.join((os.getcwd(), '/uoj_judger', path))


class RWLock:
    """Readers-writer lock that holds new readers back while a writer waits.

    State lives in ``rwlock``: ``0`` when free, ``n > 0`` while *n* readers
    hold the lock and ``-1`` while a writer holds it.  ``writers_waiting``
    counts writers blocked in `acquire_write` / `promote`.  All state is
    guarded by ``monitor``, which both condition variables share.
    """

    def __init__(self):
        self.rwlock = 0
        self.writers_waiting = 0
        self.monitor = Lock()
        self.readers_ok = Condition(self.monitor)
        self.writers_ok = Condition(self.monitor)

    def _wait_for_exclusive(self):
        """Block until no holder is left, then mark the lock write-held.

        Must be called with ``monitor`` held.
        """
        while self.rwlock != 0:
            self.writers_waiting += 1
            self.writers_ok.wait()
            self.writers_waiting -= 1
        self.rwlock = -1

    def acquire_read(self):
        """Take a shared lock; waits while a writer holds or awaits the lock."""
        with self.monitor:
            while self.rwlock < 0 or self.writers_waiting:
                self.readers_ok.wait()
            self.rwlock += 1

    def acquire_write(self):
        """Take the exclusive lock; waits until there are no holders."""
        with self.monitor:
            self._wait_for_exclusive()

    def promote(self):
        """Turn the caller's read lock into the write lock.

        The caller's read hold is given up first, then the call waits for the
        remaining holders to leave.  Another thread may obtain the write lock
        in between, so state inspected under the read lock can be stale when
        this returns.
        """
        with self.monitor:
            self.rwlock -= 1
            self._wait_for_exclusive()

    def demote(self):
        """Turn the caller's write lock into a read lock and wake readers."""
        with self.monitor:
            self.rwlock = 1
            # notify_all() is the supported spelling; the camelCase
            # notifyAll() alias is deprecated.
            self.readers_ok.notify_all()

    def release(self):
        """Release one hold (shared or exclusive) and wake whoever may proceed.

        A waiting writer is preferred once the lock is free; readers are only
        woken when no writer is waiting.  Releasing a lock that is not held is
        not detected.
        """
        with self.monitor:
            if self.rwlock < 0:
                self.rwlock = 0
            else:
                self.rwlock -= 1
            # Notify while still holding the monitor so the decision and the
            # wake-up are based on the same state.
            if self.writers_waiting and self.rwlock == 0:
                self.writers_ok.notify()
            elif self.writers_waiting == 0:
                self.readers_ok.notify_all()


class Judger:
    problem_data_lock = RWLock()
    send_and_fetch_lock = RLock()

    judger_id: int
    cpus: str
    main_path: str
    submission: Optional[dict]
    submission_judged: bool
    main_thread: Optional[Thread]
    report_thread: Optional[Thread]
    _taskQ: Queue

    def log_judge_client_status(self):
        """Log the submission currently held by this judger at INFO level."""
        logging.info(f'submission: {self.submission!s}')

    def __init__(self, judger_id, cpus):
        """Set up an idle judger.

        Args:
            judger_id: number of this judger; used in log messages and as the
                last component of its working directory.
            cpus: CPU list later handed to ``taskset -c``, or ``'all'`` to
                run without pinning.
        """
        self.judger_id = judger_id
        self.cpus = cpus
        # Working directory: /tmp/<judger_name>/<judger_id>
        self.main_path = '/'.join(('/tmp', jconf['judger_name'], str(self.judger_id)))

        # No submission is being judged yet.
        self.submission = None
        self.submission_judged = False

        # The main thread is created here but only started by start().
        self.main_thread = Thread(target=self._main_loop)
        self.report_thread = None
        self._taskQ = Queue()

    def start(self):
        """Create the judger's working directories and start its main thread.

        ``<main_path>/result`` and ``<main_path>/work`` are created together
        with any missing parents; existing directories are left alone.

        Raises:
            OSError: if a directory cannot be created.
        """
        # os.makedirs instead of a shell `mkdir -p`: the path is not passed
        # through a shell, so spaces or shell metacharacters in it are safe.
        for sub_dir in ('result', 'work'):
            os.makedirs(os.path.join(self.main_path, sub_dir), exist_ok=True)

        self.log_judge_client_status()
        logging.info('hello from judger #%d' % self.judger_id)

        self.main_thread.start()

    def suspend(self):
        """Queue a 'suspend' task and wait until the task queue is processed."""
        tasks = self._taskQ
        tasks.put('suspend')
        # join() returns once every queued task has been marked done.
        tasks.join()

    def resume(self):
        """Queue a 'resume' task and wait until the task queue is processed."""
        tasks = self._taskQ
        tasks.put('resume')
        # join() returns once every queued task has been marked done.
        tasks.join()

    def exit(self):
        """Queue an 'exit' task, wait for it to be processed, then join the main thread."""
        tasks = self._taskQ
        tasks.put('exit')
        tasks.join()
        # The main loop returns when it sees 'exit'; wait for the thread to end.
        self.main_thread.join()

    # report thread
    def _report_loop(self):
        """Forward the judging progress to the web server until judging ends.

        About every 0.2 seconds, up to 100 characters are read from
        ``<main_path>/result/cur_status.txt`` under a shared ``flock`` and
        posted with ``uoj_interact``.  Nothing is reported for hack
        submissions.  The loop ends once ``submission_judged`` is true.

        Reporting is best-effort: a missing status file or a failed request
        never interrupts judging; the failure is only logged at DEBUG level.
        """
        if 'is_hack' in self.submission:
            return

        status_path = self.main_path + '/result/cur_status.txt'
        while not self.submission_judged:
            try:
                with open(status_path, 'r') as f:
                    fcntl.flock(f, fcntl.LOCK_SH)
                    try:
                        status = f.read(100)
                    except Exception:
                        status = None
                    finally:
                        fcntl.flock(f, fcntl.LOCK_UN)

                if status is not None:
                    data = {
                        'update-status': True,
                        'id': self.submission['id'],
                    }
                    if 'is_custom_test' in self.submission:
                        data['is_custom_test'] = True
                    data['status'] = status
                    uoj_interact(data)
            except Exception:
                logging.debug('failed to report judging status', exc_info=True)

            # Sleep after every iteration, failed ones included; otherwise a
            # status file that does not exist yet makes this loop spin at
            # full speed.
            time.sleep(0.2)

    def _get_result(self):
        """Parse ``<main_path>/result/result.txt`` into a result dict.

        The file is read line by line as ``key value`` pairs.  A line starting
        with ``error`` stores the rest of the line under ``'error'``; a line
        that is exactly ``details`` stores the remainder of the file under
        ``'details'`` and ends parsing.  ``score`` (int, or float when it is
        not an integer literal), ``time`` and ``memory`` default to 0; a score
        outside ``[0, 100]`` is replaced by 0 with ``'error'`` set to
        ``'Score Error'``.  ``'status'`` is always ``'Judged'``.

        Raises:
            AssertionError: on a blank line or a line that is not ``key value``.
            ValueError: if score, time or memory is not numeric.
        """
        parsed = {'score': 0, 'time': 0, 'memory': 0}
        result_path = self.main_path + '/result/result.txt'
        with open(result_path, 'r', encoding='utf-8', errors='uoj_text_replace') as fres:
            # readline() returns '' only at end of file.
            for raw_line in iter(fres.readline, ''):
                entry = raw_line.strip()
                if entry == 'details':
                    parsed['details'] = fres.read()
                    break

                fields = entry.split()
                assert len(fields) >= 1
                if fields[0] == 'error':
                    parsed['error'] = entry[len('error') + 1:]
                    continue
                assert len(fields) == 2
                key, value = fields
                parsed[key] = value

            raw_score = parsed['score']
            try:
                parsed['score'] = int(raw_score)
            except ValueError:
                parsed['score'] = float(raw_score)
            if not (0 <= parsed['score'] <= 100):
                parsed.update({'score': 0, 'error': 'Score Error'})
            for numeric_key in ('time', 'memory'):
                parsed[numeric_key] = int(parsed[numeric_key])
        parsed['status'] = 'Judged'
        return parsed

    def _remove_problem_data(self, problem_id):
        """Delete the local data entry ``/data/<problem_id>`` of the judger.

        The entry is first made accessible to its owner (``chmod 700 -R``) and
        then removed with ``rm -rf``.

        Raises:
            Exception: if the shell command fails.
        """
        # shlex.quote is the documented quoting helper; pipes.quote was an
        # undocumented alias in a module that Python 3.13 removed.
        quoted_path = shlex.quote(uoj_judger_path(f'/data/{problem_id}'))
        execute(f'chmod 700 {quoted_path} -R && rm -rf {quoted_path}')

    def _update_problem_data_atime(self, problem_id):
        """Refresh the access time of ``/data/<problem_id>`` with ``touch -a``.

        Raises:
            Exception: if the shell command fails.
        """
        # shlex.quote is the documented quoting helper; pipes.quote was an
        # undocumented alias in a module that Python 3.13 removed.
        path = uoj_judger_path(f'/data/{problem_id}')
        execute(f'touch -a {shlex.quote(path)}')

    def _update_problem_data(self, problem_id, problem_mtime):
        """Ensure the local data of a problem is at least as new as *problem_mtime*.

        Under a read lock on ``Judger.problem_data_lock`` the local copy is
        checked; if it is current, only its access time is refreshed.
        Otherwise the lock is promoted to a write lock, a stale copy is
        removed, all but the 99 most recently accessed entries of ``/data``
        are deleted, and the problem archive is downloaded, unpacked, made
        read-only and stamped with *problem_mtime*.

        Args:
            problem_id: numeric id of the problem (formatted with ``%d``).
            problem_mtime: modification timestamp the local copy must reach.

        Raises:
            Exception: if any step fails; the underlying error is logged and
                attached as the cause.
        """
        Judger.problem_data_lock.acquire_read()
        try:
            copy_name = uoj_judger_path(f'/data/{problem_id}')
            copy_zip_name = uoj_judger_path(f'/data/{problem_id}.zip')

            def up_to_date():
                return os.path.isdir(copy_name) and os.path.getmtime(copy_name) >= problem_mtime

            if up_to_date():
                self._update_problem_data_atime(problem_id)
                return

            Judger.problem_data_lock.promote()

            # promote() gives up the read hold before waiting, so another
            # judger may have refreshed this problem in the meantime.
            # Check again instead of unpacking over a fresh copy.
            if up_to_date():
                self._update_problem_data_atime(problem_id)
                return

            if os.path.isdir(copy_name):
                self._remove_problem_data(problem_id)

            # Keep only the 99 most recently accessed entries.
            del_list = sorted(os.listdir(uoj_judger_path('/data')),
                              key=lambda p: os.path.getatime(uoj_judger_path(f'/data/{p}')))[:-99]
            for p in del_list:
                self._remove_problem_data(p)

            uoj_download(f'/problem/{problem_id}', copy_zip_name)
            execute('cd %s && unzip -q %d.zip && rm %d.zip && chmod -w %d -R' % (
                shlex.quote(uoj_judger_path('/data')), problem_id, problem_id, problem_id))
            os.utime(uoj_judger_path(f'/data/{problem_id}'), (time.time(), problem_mtime))
        except Exception as e:
            self.log_judge_client_status()
            logging.exception('problem update error')
            raise Exception(f'failed to update problem data of #{problem_id}') from e
        else:
            # Not reached on the early returns above: nothing was updated then.
            self.log_judge_client_status()
            logging.info(f'updated problem data of #{problem_id} successfully')
        finally:
            # release() handles both the read hold and the promoted write hold.
            Judger.problem_data_lock.release()

    def _judge(self):
        """Judge the current submission and return its parsed result.

        Steps: empty the ``work`` and ``result`` directories, bring the
        problem data up to date, download and unpack the submission archive,
        write ``work/submission.conf`` (plus hack input / mode flags where
        applicable), then run ``main_judger`` while a report thread forwards
        progress.  ``main_judger`` is pinned with ``taskset -c`` unless
        ``cpus`` is ``'all'`` and runs under a read lock on the problem data.

        Returns:
            The dict produced by ``_get_result``.

        Raises:
            Exception: if any preparation step or ``main_judger`` fails.
        """
        clean_up_folder(self.main_path + '/work')
        clean_up_folder(self.main_path + '/result')
        self._update_problem_data(self.submission['problem_id'], self.submission['problem_mtime'])

        with open(self.main_path + '/work/submission.conf', 'w') as fconf:
            uoj_download(self.submission['content']['file_name'], self.main_path + '/work/all.zip')
            execute("cd %s && unzip -q all.zip && rm all.zip" % shlex.quote(self.main_path + '/work'))
            for k, v in self.submission['content']['config']:
                print(k, v, file=fconf)

            if 'is_hack' in self.submission:
                if self.submission['hack']['input_type'] == 'USE_FORMATTER':
                    uoj_download(self.submission['hack']['input'], self.main_path + '/work/hack_input_raw.txt')
                    execute('%s <%s >%s' % (
                        shlex.quote(uoj_judger_path('/run/formatter')),
                        shlex.quote(self.main_path + '/work/hack_input_raw.txt'),
                        shlex.quote(self.main_path + '/work/hack_input.txt')
                    ))
                else:
                    uoj_download(self.submission['hack']['input'], self.main_path + '/work/hack_input.txt')
                print('test_new_hack_only on', file=fconf)
            elif 'is_custom_test' in self.submission:
                print('custom_test on', file=fconf)

        self.submission_judged = False
        self.report_thread = Thread(target=self._report_loop)
        self.report_thread.start()

        try:
            Judger.problem_data_lock.acquire_read()
            try:
                main_judger_cmd = shlex.quote(uoj_judger_path('/main_judger'))
                if self.cpus != 'all':
                    main_judger_cmd = 'taskset -c %s %s' % (shlex.quote(self.cpus), main_judger_cmd)
                execute('cd %s && %s' % (shlex.quote(self.main_path), main_judger_cmd))
            finally:
                Judger.problem_data_lock.release()
        finally:
            # Always stop and reap the report thread, even when main_judger
            # failed; otherwise it would keep polling forever.
            self.submission_judged = True
            self.report_thread.join()
            self.report_thread = None

        return self._get_result()

    def _send_and_fetch(self, result=None, fetch_new=True):
        """Send a judgement result and/or fetch the next submission to judge.

        Args:
            result: result dict of the current submission, ``False`` when
                judging failed (reported as 'Judgment Failed'), or ``None``
                when there is nothing to submit.  A dict passed in gets its
                ``'status'`` set to ``'Judged'``.
            fetch_new: when false, ``fetch_new=False`` is sent with the request.

        The request is retried every 2 seconds until ``uoj_interact``
        succeeds.  For a successful hack the files ``work/hack_input.txt``
        and ``work/std_output.txt`` are uploaded as well.

        Returns:
            True if the reply was JSON (stored in ``self.submission``);
            False otherwise, in which case ``self.submission`` becomes None.
        """

        data = {}
        files = {}

        if not fetch_new:
            data['fetch_new'] = False

        if result is not None:
            data['submit'] = True
            if 'is_hack' in self.submission:
                data['is_hack'] = True
                data['id'] = self.submission['hack']['id']
                if result is not False and result['score']:
                    try:
                        logging.info("succ hack!")
                        # Read the uploads into memory: the files are closed
                        # right away and every retry below sends the full
                        # content (an open file would be at EOF after the
                        # first attempt).
                        uploads = {}
                        for field in ('hack_input', 'std_output'):
                            file_name = field + '.txt'
                            with open(self.main_path + '/work/' + file_name, 'rb') as fp:
                                uploads[field] = (file_name, fp.read())
                        files = uploads
                    except Exception:
                        self.log_judge_client_status()
                        logging.exception('hack: submit error')
                        result = False
            elif 'is_custom_test' in self.submission:
                data['is_custom_test'] = True
                data['id'] = self.submission['id']
            else:
                data['id'] = self.submission['id']

            if 'judge_time' in self.submission:
                data['judge_time'] = self.submission['judge_time']

            if 'tid' in self.submission:
                data['tid'] = self.submission['tid']

            if result is False:
                result = {
                    'score': 0,
                    'error': 'Judgment Failed',
                    'details': 'Unknown Error'
                }
            result['status'] = 'Judged'
            data['result'] = json.dumps(result, ensure_ascii=False)

        while True:
            try:
                # The class-wide lock serializes this exchange between judgers.
                with Judger.send_and_fetch_lock:
                    ret = uoj_interact(data, files)
            except Exception:
                self.log_judge_client_status()
                logging.exception('uoj_interact error')
            else:
                break
            time.sleep(2)

        try:
            self.submission = json.loads(ret)
        except Exception:
            if ret != 'Nothing to judge':
                logging.info(ret)
            self.submission = None
            return False
        else:
            return True

    def _main_loop(self):
        """Body of the judger's main thread.

        Repeats: process pending control tasks, fetch a submission, and judge
        submissions back to back for as long as the server keeps returning
        one.  Once any control task is pending, tasks are consumed (blocking)
        until 'resume' arrives; 'exit' ends the thread.  When nothing is
        fetched the loop sleeps 2 seconds before trying again.
        """
        while True:
            if not self._taskQ.empty():
                # Every task is acknowledged; only 'resume' and 'exit' leave
                # this inner wait.
                while True:
                    task = self._taskQ.get()
                    self._taskQ.task_done()
                    if task == 'exit':
                        return
                    if task == 'resume':
                        break

            fetched = self._send_and_fetch()
            if not fetched:
                time.sleep(2)
                continue

            self.log_judge_client_status()
            logging.info('judging')

            has_submission = True
            while has_submission:
                try:
                    outcome = self._judge()
                except Exception:
                    self.log_judge_client_status()
                    logging.exception('judge error')
                    outcome = False
                # Only ask for another submission while no control task waits.
                has_submission = self._send_and_fetch(result=outcome, fetch_new=self._taskQ.empty())


class JudgerServer:
    """Runs the judgers and a local control socket that accepts commands.

    One ``Judger`` is created per entry of ``jconf['judger_cpus']``.  Control
    commands arrive as JSON objects (``password`` + ``cmd``) on
    ``jconf['socket_port']`` and are handled by `run` through ``taskQ``.
    """

    taskQ: Queue
    judgers: List['Judger']
    socket_thread: Thread

    def __init__(self):
        self.taskQ = Queue()
        self.judgers = [Judger(i, cpus) for i, cpus in enumerate(jconf['judger_cpus'])]
        self.socket_thread = Thread(target=self._socket_loop)

    @staticmethod
    def update(broadcast=True):
        """Download the judger sources and rebuild them.

        A failed download or build is logged and otherwise ignored.  When
        *broadcast* is true and this client is the 'main_judger', the other
        judge clients are told to sync afterwards.

        Raises:
            Exception: if syncing the judge clients fails.
        """
        try:
            uoj_download('/judger', 'judger_update.zip')
            execute('unzip -o judger_update.zip && cd %s && make clean && make'
                    % shlex.quote(uoj_judger_path()))
        except Exception:
            logging.exception('error when update')
        if broadcast and jconf['judger_name'] == 'main_judger':
            uoj_sync_judge_client()

    def _socket_loop(self):
        """Accept control connections and queue the commands they carry.

        Each connection must send a JSON object whose ``password`` equals
        ``jconf['socket_password']`` and which has a ``cmd``; anything else is
        logged as 'connection rejected'.  A 'stop' command is answered with
        ``b'ok'`` once `run` has processed it, and ends this loop.  If the
        server socket itself fails, a 'stop' task is queued.
        """
        try:
            SOCK_CLOEXEC = 524288
            with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM | SOCK_CLOEXEC)) as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(('', jconf['socket_port']))
                s.listen(5)

                while True:
                    try:
                        conn, addr = s.accept()
                        with closing(conn) as conn:
                            data = conn.recv(1024)
                            task = json.loads(data)
                            # Explicit checks rather than `assert`, which is
                            # stripped under `python -O` and would disable
                            # the password check.
                            if task['password'] != jconf['socket_password']:
                                raise PermissionError('wrong socket password')
                            if 'cmd' not in task:
                                raise ValueError("task has no 'cmd'")

                            self.taskQ.put(task)

                            if task['cmd'] == 'stop':
                                logging.info('the judge client is closing...')
                                self.taskQ.join()
                                conn.sendall(b'ok')
                                return 'stop'
                    except Exception:
                        logging.exception('connection rejected')
                    else:
                        logging.info('a new task accomplished')
        except Exception:
            self.taskQ.put({'cmd': 'stop'})
            logging.exception('socket server error!')

    def run(self):
        """Start all threads and process control commands forever.

        For every batch of queued commands the judgers are suspended first.
        'update' / 'self-update' rebuild the judger and then re-exec
        ``./judge_client``; 'stop' shuts everything down and exits the
        process with status 0.  Unknown commands are acknowledged and ignored.
        """
        self.socket_thread.start()
        for judger in self.judgers:
            judger.start()

        while True:
            task = self.taskQ.get()

            need_restart = False

            for judger in self.judgers:
                judger.suspend()

            try:
                while True:
                    if task['cmd'] in ['update', 'self-update']:
                        # NOTE(review): main() passes broadcast=False for
                        # 'self-update', but this loop has always synced for
                        # both commands; confirm which is intended.
                        self.update()
                        need_restart = True
                    elif task['cmd'] == 'stop':
                        self.taskQ.task_done()
                        self.socket_thread.join()
                        for judger in self.judgers:
                            judger.exit()
                        logging.info("goodbye!")
                        sys.exit(0)

                    self.taskQ.task_done()
                    # Drain further queued commands without blocking; Empty
                    # ends the batch.
                    task = self.taskQ.get(block=False)
            except Empty:
                pass

            if need_restart:
                os.execl('./judge_client', './judge_client')

            for judger in self.judgers:
                judger.resume()


# interact with uoj web server
# TODO: pass an explicit timeout to the requests calls below (none is set, so they can block indefinitely)
def uoj_interact(data, files=None):
    """POST *data* (plus the judger credentials) to ``/judge/submit``.

    Args:
        data: form fields to send; the dict is copied, not modified.
        files: optional uploads in the form ``requests`` accepts for its
            ``files`` argument.  Defaults to no uploads.

    Returns:
        The response body as text.
    """
    # `None` instead of a shared mutable `{}` default.
    if files is None:
        files = {}
    data = data.copy()
    data.update({
        'judger_name': jconf['judger_name'],
        'password': jconf['judger_password']
    })
    return requests.post(uoj_url('/judge/submit'), data=data, files=files).text


def uoj_download(uri, filename):
    """Download ``/judge/download<uri>`` from the web server into *filename*.

    The request is authenticated with the judger credentials and the body is
    streamed to disk in 64 KiB chunks.  *filename* is only opened after the
    server has answered successfully.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never saved as if it were the requested file.
        requests.RequestException, OSError: on network or file errors.
    """
    data = {
        'judger_name': jconf['judger_name'],
        'password': jconf['judger_password']
    }
    # closing() returns the streamed connection even when writing fails.
    with closing(requests.post(uoj_url('/judge/download' + uri), data=data, stream=True)) as r:
        r.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=65536):
                if chunk:
                    f.write(chunk)


def uoj_sync_judge_client():
    """POST the judger credentials to ``/judge/sync-judge-client``.

    Raises:
        Exception: if the server's reply is anything other than ``ok``.
    """
    credentials = dict(
        judger_name=jconf['judger_name'],
        password=jconf['judger_password'],
    )
    response = requests.post(uoj_url('/judge/sync-judge-client'), data=credentials)
    ret = response.text
    if ret == "ok":
        return
    raise Exception('failed to sync judge clients: %s' % ret)


# main function
def main():
    """Command-line entry point of the judge client.

    Without arguments the server runs in the foreground.  With one argument:
    ``start`` forks and runs the server in the child (stdout to the null
    device, stderr appended to ``log/judge.log``); ``update`` /
    ``self-update`` send that command to the running server over the local
    control socket, falling back to ``JudgerServer.update`` when the socket
    cannot be used; ``stop`` sends 'stop' and expects ``ok`` back.

    Raises:
        Exception: 'invalid argument' for any other invocation, or
            'update failed' / 'stop failed' when the command fails.
    """
    init()

    logging.basicConfig(filename='log/judge.log', level=logging.INFO,
                        format='%(asctime)-15s [%(levelname)s]: %(message)s')

    def send_command(sock, cmd):
        # Connect to the local control port and send the password with *cmd*.
        sock.connect(('127.0.0.1', jconf['socket_port']))
        message = {
            'password': jconf['socket_password'],
            'cmd': cmd
        }
        sock.sendall(json.dumps(message).encode())

    argc = len(sys.argv)
    if argc == 1:
        JudgerServer().run()
    if argc == 2:
        command = sys.argv[1]
        if command == 'start':
            pid = os.fork()
            if pid == -1:
                raise Exception('fork failed')
            if pid > 0:
                # Parent: the child carries on as the server.
                return
            freopen(sys.stdout, open(os.devnull, 'w'))
            freopen(sys.stderr, open('log/judge.log', 'a', encoding='utf-8', errors='uoj_text_replace'))
            JudgerServer().run()
        elif command in ['update', 'self-update']:
            try:
                try:
                    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                        send_command(s, command)
                    return
                except OSError:
                    # The socket could not be used: update directly instead.
                    JudgerServer.update(broadcast=command == 'update')
                    return
            except Exception:
                logging.exception('update error')
                raise Exception('update failed')
        elif command == 'stop':
            try:
                with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                    send_command(s, 'stop')
                    if s.recv(10) != b'ok':
                        raise Exception('stop failed')
                return
            except Exception:
                logging.exception('stop error')
                raise Exception('stop failed')
    raise Exception('invalid argument')


# Script entry: run main(); any exception is logged and turned into exit
# status 1.
try:
    main()
except Exception:
    logging.exception('critical error!')
    raise SystemExit(1)