#!/usr/bin/env python3
# encoding: utf-8
"""
bgp

Created by Thomas Mangin
Copyright (c) 2013-2017 Exa Networks. All rights reserved.
License: 3-clause BSD. (See the COPYRIGHT file)
"""

import os
import socket
import sys
import time
import queue
import threading
import signal
import asyncio
import argparse

from struct import unpack, pack
from enum import Enum


class FSM(Enum):
    """States of the per-connection message loop run by BGPProtocol.handle_bgp."""

    BUG = 0  # never assigned in this file; NOTE(review): presumably a placeholder for an invalid state
    OPEN = 1  # initial state: the next message is handled by handle_open()
    MSG = 2  # session established: messages are handled by handle_message()
    STOP = 3  # leave the message loop and close the writer


def bytestream(value):
    """Return the upper-case, two-digit hexadecimal rendering of each byte value, concatenated."""
    return ''.join('%02X' % octet for octet in value)


def flushed(output):
    """Print *output* on standard output and flush the stream straight away."""
    print(output, flush=True)


def indent(msg, indent=12):
    """Print *msg* preceded by *indent* spaces, then flush standard output."""
    padding = ' ' * indent
    sys.stdout.write(padding)
    flushed(msg)


def print_prefixed(prefix, message):
    """Print *message* after *prefix*, the prefix being padded to a 12-character column."""
    line = f'{prefix:12s}{message}'
    flushed(line)


# FIXME: we may be able to not have to use this
def msg_bytes(message):
    """Convert a hexadecimal string into the bytes it spells out.

    An odd number of characters is reported on stdout but still converted;
    the trailing lone digit becomes a byte of its own.
    """
    if len(message) % 2:
        flushed(f'invalid BGP message:\n    {message}')
    pairs = (message[start : start + 2] for start in range(0, len(message), 2))
    return b''.join(int(pair, 16).to_bytes(1, byteorder='big') for pair in pairs)


def format_message(header, body):
    """Join the first 16 characters of *header*, the next two, the remainder and *body* with colons."""
    first, middle, last = header[:16], header[16:18], header[18:]
    return '{}:{}:{}:{}'.format(first, middle, last, body)


def print_payload(prefix, header, body):
    """Print a message given as raw header and body bytes, hex encoded, behind *prefix*."""
    hexadecimal = bytestream(header) + bytestream(body)
    print_rule(prefix, hexadecimal)


def print_rule(prefix, rule):
    """Print a hex-encoded message split after its 32nd, 36th and 38th characters.

    For a hex-encoded BGP message those cuts separate the 16-byte marker,
    the 2-byte length, the 1-byte type and the body.
    """
    marker, length, kind, payload = rule[:32], rule[32:36], rule[36:38], rule[38:]
    print_tab(prefix, '{}:{}:{}:{}'.format(marker, length, kind, payload))


def print_tab(prefix, rule):
    """Print *rule* after *prefix*, the prefix being padded to a 12-character column."""
    flushed('{:12}{}'.format(prefix, rule))


def dump(value):
    """Return *value* as upper-case hex, with a space between each group of two bytes."""
    pieces = []
    for position, octet in enumerate(value):
        # a separator goes in front of every even-positioned byte but the first
        if position and position % 2 == 0:
            pieces.append(' ')
        pieces.append('%02X' % octet)
    return ''.join(pieces)


def cdr_to_length(cidr):
    """Return how many bytes (0 to 4) are used to carry a prefix of *cidr* bits."""
    for size, threshold in ((4, 24), (3, 16), (2, 8), (1, 0)):
        if cidr > threshold:
            return size
    return 0


def any_word_in_line(line, words):
    """Return True when at least one of *words* occurs somewhere in *line*."""
    return any(word in line for word in words)


def any_ends_in_line(line, words):
    """Return True when *line* ends with at least one of *words*."""
    return any(line.endswith(word) for word in words)


def any_word_in_list(searched, words):
    """Return True when at least one of *words* occurs inside at least one entry of *searched*."""
    return any(word in entry for word in words for entry in searched)


def any_ends_in_list(searched, words):
    """Return True when at least one entry of *searched* ends with at least one of *words*."""
    return any(entry.endswith(word) for word in words for entry in searched)


def find_exabgp_pid():
    """Locate the single running process whose command line matches our configuration.

    The configuration name is the base name (without extension) of
    ``sys.argv[2]``.  A line of ``/bin/ps x`` is accepted when it mentions a
    python/pypy interpreter, one of its arguments ends in ``/main.py``, one
    argument contains the configuration name and one contains ``.conf``.

    Returns:
        A ``(pid, command_line)`` tuple of strings.

    Exits the program with status 1 when no process, or more than one
    process, matches.
    """
    # --view is argv[1]
    conf_name = sys.argv[2].split('/')[-1].split('.')[0]

    processes = []
    cmdlines = []

    # the context manager closes the pipe (and reaps ps) once the listing is read
    with os.popen('/bin/ps x') as listing:
        for line in listing:
            low = line.strip().lower()
            if not low:
                continue

            if not any_word_in_line(low, ['/python ', '/python3 ', '/pypy ', '/pypy3 ']):
                continue

            fields = line.strip().split()
            # the command starts at the fifth column of the `ps x` output
            cmdline = fields[4:]
            pid = fields[0]

            if not cmdline:
                continue

            if not any_ends_in_list(cmdline, ['/main.py']):
                continue

            if not any_word_in_list(cmdline, [conf_name]):
                continue

            if not any_word_in_list(cmdline, ['.conf']):
                continue

            processes.append(pid)
            cmdlines.append(' '.join(cmdline))

    if not processes:
        flushed('no process found, quitting')
        sys.exit(1)

    if len(processes) > 1:
        flushed('more than one process running, quitting')
        sys.exit(1)

    return processes[0], cmdlines[0]


def kill(signal_name='SIGUSR1'):
    """Send the named signal to the process returned by find_exabgp_pid().

    Args:
        signal_name: name of a signal such as ``'SIGUSR1'`` (case-insensitive).

    Raises:
        ValueError: when *signal_name* is not the name of a signal.

    Exits the program with status 1 when the signal cannot be delivered.
    """
    # Only real signals are accepted: the signal module also exposes other
    # SIG* constants (SIG_IGN, SIG_BLOCK, ...) which are not signal numbers.
    try:
        number = signal.Signals[signal_name.upper()]
    except KeyError:
        raise ValueError(f'invalid signal name: {signal_name}') from None

    pid, cmd_line = find_exabgp_pid()

    flushed(f'\nsending signal {signal_name} to pid {pid}')
    flushed(f'pid {pid}: {cmd_line}\n')

    try:
        os.kill(int(pid), number)
    except Exception as exc:
        flushed('\nfailed: %s' % str(exc))
        sys.exit(1)


class Message:
    """One BGP message, stored as its raw 19-byte header and its raw body.

    The static builders return complete, ready-to-send messages.  Instances
    wrap a received message and offer type checks, hex rendering, a minimal
    IPv4 UPDATE decoder and the construction of the OPEN sent in reply.
    """

    # 2-byte AS number sent in place of an ASN which needs four octets (RFC 6793)
    AS_TRANS = 23456

    @staticmethod
    def keepalive():
        """Return a KEEPALIVE message: marker, length 19, type 4."""
        return bytearray([0xFF] * 16 + [0x00, 0x13, 0x04])

    @staticmethod
    def eor():
        """Return an UPDATE with no withdrawn routes, no attributes and no NLRI."""
        return bytearray(
            [0xFF] * 16
            + [0x00, 0x17, 0x02]  # length 23, type UPDATE
            + [0x00, 0x00, 0x00, 0x00]  # withdrawn length 0, attributes length 0
        )

    @staticmethod
    def default_route():
        """Return the UPDATE sent when the 'send-default-route' option is set."""
        return bytearray(
            [0xFF] * 16
            + [0x00, 0x31]  # message length: 49
            + [0x02]  # type: UPDATE
            + [0x00, 0x00]  # withdrawn routes length
            + [0x00, 0x15]  # path attributes length: 21
            + [0x40, 0x01, 0x01, 0x00]  # ORIGIN
            + [0x40, 0x02, 0x00]  # AS_PATH (empty)
            + [0x40, 0x03, 0x04, 0x7F, 0x00, 0x00, 0x01]  # NEXT_HOP 127.0.0.1
            + [0x40, 0x05, 0x04, 0x00, 0x00, 0x00, 0x64]  # LOCAL_PREF 100
            # NOTE(review): this NLRI encodes 0.0.0.0/32, not 0.0.0.0/0 - confirm it is intended
            + [0x20, 0x00, 0x00, 0x00, 0x00]
        )

    @staticmethod
    def notify(notification):
        """Return a NOTIFICATION (code 6, subcode 0) carrying *notification* as UTF-8 data.

        The length field is computed from the encoded data and written as a
        16-bit value, so non-ASCII and long texts produce a consistent header.
        """
        data = notification.encode('utf-8')
        return bytearray(
            b'\xff' * 16
            + pack('!H', 19 + 2 + len(data))
            + bytes([0x03])  # type: NOTIFICATION
            + bytes([0x06, 0x00])  # error code 6, subcode 0
            + data
        )

    # keyed by the integer returned by kind()
    _name = {
        1: 'OPEN',
        2: 'UPDATE',
        3: 'NOTIFICATION',
        4: 'KEEPALIVE',
        5: 'ROUTE-REFRESH',
    }

    def __init__(self, header, body):
        self.header = header
        self.body = body

    def kind(self):
        """Return the message type, the byte at offset 18 of the header."""
        return self.header[18]

    def is_open(self):
        return self.kind() == 1

    def is_update(self):
        return self.kind() == 2

    def is_notification(self):
        return self.kind() == 3

    def is_keepalive(self):
        return self.kind() == 4

    def is_route_refresh(self):
        return self.kind() == 5

    def is_eor(self):
        """Return True for an UPDATE whose body is 4 or 11 bytes long."""
        if not self.is_update():
            return False
        return len(self.body) in [4, 11]

    def name(self, header=None):
        """Return the name of this message type.

        *header* is ignored; it is only accepted so existing callers keep working.
        """
        return self._name.get(self.kind(), 'SOME WEIRD RFC PACKET')

    def notification(self):
        """Yield the body as a hexadecimal string."""
        yield bytestream(self.body)
        # yield 'notification:%d,%d' % (self.body[0], self.body[1])

    def stream(self):
        """Return header and body as one upper-case hexadecimal string."""
        return bytestream(self.header + self.body)

    def packet(self):
        """Return header and body as one bytearray."""
        return bytearray(self.header + self.body)

    @staticmethod
    def _prefixes(action, data):
        """Yield '<action>:a.b.c.d/len' for every IPv4 prefix encoded in *data*."""
        while data:
            cdr, data = data[0], data[1:]
            octets = [0, 0, 0, 0]
            for index in range(cdr_to_length(cdr)):
                octets[index], data = data[0], data[1:]
            address = '.'.join(str(octet) for octet in octets)
            yield f'{action}:{address}/{cdr}'

    def routes(self):
        """Yield a textual description of the routes carried by an UPDATE body.

        Yields 'withdraw:<prefix>' then 'announce:<prefix>' entries.  When the
        UPDATE has neither withdrawn routes nor NLRI it yields a single
        'eor:<a>:<b>' (body of 4 or 11 bytes) or 'mp:' entry instead.
        """
        body_len = len(self.body)

        len_w = unpack('!H', self.body[0:2])[0]
        withdrawn = bytearray(self.body[2 : 2 + len_w])

        len_a = unpack('!H', self.body[2 + len_w : 2 + len_w + 2])[0]
        announced = bytearray(self.body[2 + len_w + 2 + len_a :])

        if not withdrawn and not announced:
            if body_len == 4:
                yield 'eor:1:1'
            elif body_len == 11:
                yield 'eor:%d:%d' % (self.body[-2], self.body[-1])
            else:  # undecoded MP route
                yield 'mp:'
            return

        yield from self._prefixes('withdraw', withdrawn)
        yield from self._prefixes('announce', announced)

    @staticmethod
    def _add_capa66(adding, open):
        """Return *open* with an unknown capability (code 66) appended when *adding* is true.

        The message length is updated as a 16-bit value so messages whose
        length crosses a 256-byte boundary are still encoded correctly.
        """
        if not adding:
            return open

        # hack capability 66 into the message
        content = b'loremipsum'
        cap66 = bytes([66, len(content)]) + content
        param = bytes([2, len(cap66)]) + cap66
        length = unpack('!H', open[16:18])[0] + len(param)
        return (
            open[:16]
            + pack('!H', length)
            + open[18:28]
            # byte 28 is the optional parameters length
            + bytes([open[28] + len(param)])
            + open[29:]
            + param
        )

    @staticmethod
    def _find_as4_value(opt_params):
        """Return the offset of the 4-octet ASN inside *opt_params*, or None.

        Walks every capability (code 65 is searched for) of every
        Capabilities optional parameter (type 2).
        """
        offset = 0
        while offset + 2 <= len(opt_params):
            param_type = opt_params[offset]
            param_len = opt_params[offset + 1]
            if param_type == 2:  # Capabilities Optional Parameter
                cap = offset + 2
                end = min(cap + param_len, len(opt_params))
                while cap + 2 <= end:
                    if opt_params[cap] == 65:  # Support for 4 octet AS number capability
                        return cap + 2
                    cap += 2 + opt_params[cap + 1]
            offset += param_len + 2
        return None

    def open(self, asn, add_capa66):
        """Build the OPEN to send back, derived from the OPEN this message holds.

        Args:
            asn: AS number to advertise; when falsy the received OPEN is
                mirrored with the last byte of the router-id incremented.
            add_capa66: when true, append an unknown capability (code 66).

        Returns:
            The complete OPEN message (header included).
        """
        # lower byte of the router_id
        byte_id = self.body[8]

        if not asn:
            bumped = bytes([(byte_id + 1) & 0xFF])
            message = self.header + self.body[:8] + bumped + self.body[9:]
            return self._add_capa66(add_capa66, message)

        # the 2-byte field can not hold a four-octet ASN: advertise AS_TRANS instead
        short_asn = asn if asn <= 0xFFFF else self.AS_TRANS
        fixed = self.header + self.body[0:1] + pack('!H', short_asn) + self.body[3:10]

        # Check if peer supports Four-Octet Autonomous System (RFC6793)
        opt_params = self.body[10:]
        value_at = self._find_as4_value(opt_params)
        if value_at is not None:
            opt_params = opt_params[:value_at] + pack('!I', asn) + opt_params[value_at + 4 :]

        # without the capability only the 16-bit ASN field is replaced
        return self._add_capa66(add_capa66, fixed + opt_params)


class Checker(object):
    """Tracks the messages the peer is expected to send and validates what arrives.

    Attributes:
        raw: True when at least one rule used the 'raw' encoding.
        messages: rules of the sequence currently being matched.
        sequences: remaining sequences (lists of rules), in the order they are expected.
        updates_seen: (header, body) of every accepted non-EOR UPDATE,
            replayed when a route refresh is received.
    """

    def __init__(self, expected):
        self.raw = False
        self.messages = []
        # per-instance list: a class-level list would be shared by every Checker
        self.updates_seen = []

        self.sequences = self.group_messages(expected)

    def expecting(self):
        """Print the rules of the sequence currently being matched."""
        flushed('\nexpecting:')
        for message in self.messages:
            indent(message)

    def rules(self):
        """Print every remaining sequence, numbered, separated by a blank line."""
        flushed('rules:')
        for idx, messages in enumerate(self.sequences, start=1):
            if idx > 1:
                flushed('')
            for message in messages:
                if message.startswith('F' * 32):
                    print_rule(f'       {idx:02}', f'{message}')
                    continue
                print_prefixed(f'       {idx:02}', f'{message}')

    # split all the messages and group them with the provided index
    # 1:raw:<raw bgp>
    # 1:raw:<raw bgp>
    # 2:raw:<raw bgp>
    # will return a list of two lists, first with 2 elements, second with one
    # A1:raw:<raw bgp>
    # B1:raw:<raw bgp>
    # B2:raw:<raw bgp>
    # C1:raw:<raw bgp>
    # is for when we have two different BGP connections

    # FIXME: we need to redo how we deal with raw vs parsed packets

    def group_messages(self, expected):
        """Group the rules by connection letter and sequence number.

        Returns:
            A list of rule lists ordered by connection then by sequence
            number.  Raw rules are stored as upper-case hex without colons,
            every other rule as lower-case '<encoding>:<content>'.

        Exits the program with status 1 on a rule which can not be parsed.
        """
        group = {}
        for rule in expected:
            if 'notification:' not in rule:
                rule = rule.replace(' ', '').lower()

            try:
                prefix, encoding, content = rule.split(':', 2)
                if prefix[:1].isalpha():
                    # notification rules keep their case while the others are
                    # lower-cased: normalise so both land on the same connection
                    conn = prefix[0].upper()
                    seq = int(prefix[1:])
                else:
                    conn = 'A'
                    seq = int(prefix)
            except ValueError:
                flushed(f'invalid rule: {rule}')
                sys.exit(1)

            raw = encoding == 'raw'
            self.raw = self.raw or raw
            if raw:
                content = content.replace(':', '')
                group.setdefault(conn, {}).setdefault(seq, []).append(content.upper())
            else:
                group.setdefault(conn, {}).setdefault(seq, []).append(f'{encoding.lower()}:{content.lower()}')

        ordered = []
        for conn in sorted(group):
            source = group[conn]
            for seq in sorted(source):
                ordered.append(source[seq])
        return ordered

    def init(self):
        """Load the first sequence for a new session.

        Returns:
            False when rules of a previous session are still pending or no
            sequence is left, otherwise this checker.
        """
        if self.messages:
            return False
        if not self.sequences:
            return False

        self.messages = self.sequences.pop(0)
        return self

    def _send_signal_if_requested(self):
        """Consume a leading 'signal:<name>' rule and send that signal; return whether one was sent."""
        if not self.messages:
            return False

        if not self.messages[0].startswith('signal:'):
            return False

        message = self.messages.pop(0)
        self._update_messages_if_required()

        name = message.split(':')[-1]
        kill(name)
        return True

    def _send_notification_if_requested(self, writer):
        """Consume a leading 'notification:<text>' rule and write that NOTIFICATION to *writer*."""
        if not self.messages:
            return False

        if not self.messages[0].startswith('notification:'):
            return False

        notification = self.messages.pop(0).split(':')[-1]
        writer.write(Message.notify(notification))
        flushed(f'\nsending closing notification: "{notification}"\n')

        self._update_messages_if_required()
        return True

    def _send_rr_if_required(self, writer, msg):
        """On a route refresh, replay every UPDATE seen so far followed by a KEEPALIVE."""
        if msg.is_route_refresh():
            for header, body in self.updates_seen:
                print_payload('rr   sent', header, body)
                writer.write(Message(header, body).packet())
            writer.write(Message.keepalive())

    def perform_actions_if_required(self, writer, msg):
        """Run the pending signal / notification rules, then answer a route refresh."""
        self._update_messages_if_required()
        self._send_signal_if_requested()
        self._update_messages_if_required()
        self._send_notification_if_requested(writer)
        self._send_rr_if_required(writer, msg)

    def _update_messages_if_required(self):
        """Move on to the next sequence once the current one is exhausted."""
        if not self.messages and self.sequences:
            self.messages = self.sequences.pop(0)

    def expected(self, writer, msg):
        """Check a received message against the current sequence.

        KEEPALIVEs are always accepted; so are EORs once every rule has been
        matched.  A matching rule is removed and the follow-up actions are
        performed.  Otherwise the message and the valid options are printed.

        Returns:
            True when the message was acceptable, False otherwise.
        """
        if msg.is_keepalive():
            return True

        # make sure an exhausted sequence is replaced before looking at it
        self._update_messages_if_required()

        if not self.sequences and not self.messages:
            if msg.is_eor():
                return True
            flushed('received extra message')
            print_payload('additional', msg.header, msg.body)
            return False

        stream = msg.stream()

        for check in self.messages:
            received = stream
            # raw rules without the marker are compared against the message minus its marker
            if not check.startswith('F' * 32) and ':' not in check:
                received = received[32:]

            if check == received:
                self.messages.remove(check)
                if msg.is_update() and not msg.is_eor():
                    self.updates_seen.append((msg.header, msg.body))

                self.perform_actions_if_required(writer, msg)
                return True

        nb_options = len(self.messages)
        flushed('')
        flushed('unexpected message:')
        print_payload('received', msg.header, msg.body)

        flushed(f'\ncounting {nb_options} valid option(s):')
        for idx, message in enumerate(self.messages, start=1):
            label = f'option {idx:02}'
            # only raw rules are hexadecimal, the others are shown as written
            if ':' in message:
                print_prefixed(label, message)
                continue
            try:
                raw = msg_bytes(message)
            except ValueError:
                print_prefixed(label, message)
                continue
            print_payload(label, raw[:19], raw[19:])
        flushed('')
        self.rules()
        return False

    def completed(self):
        """Return True once every rule of every sequence has been consumed."""
        return len(self.messages) == 0 and len(self.sequences) == 0


class BGPService:
    """Shared state of the daemon: event loop, result queue, options and checker."""

    def __init__(self, loop, queue, options, checker):
        self.loop, self.queue = loop, queue
        self.options, self.checker = options, checker

    def exit(self, code):
        """Record *code* on the queue, then stop the event loop."""
        self.queue.put(code)
        self.loop.stop()


class BGPProtocol(asyncio.Protocol):
    """Handles one accepted TCP connection: reads BGP messages and answers them.

    Attributes:
        service: the shared BGPService (options, checker, exit reporting).
        reader / writer: the asyncio stream pair of the connection.
        counter: number of messages handled by handle_message().
    """

    counter = 0

    def __init__(self, service, reader, writer):
        self.service = service
        self.reader = reader
        self.writer = writer

        if not service.checker.init():
            flushed('some messages are from previous session:')

    async def read_message(self):
        """Read one complete BGP message from the connection.

        Returns:
            ``(header, body)`` as bytes, or ``('', '')`` when the connection
            ended or failed before a whole message was received.
        """
        try:
            # readexactly() waits for all the bytes: read() may return a partial message
            header = await self.reader.readexactly(19)
            length = unpack('!H', header[16:18])[0]
            # a length below 19 is invalid; never ask for a negative size
            body = await self.reader.readexactly(max(length - 19, 0))
        except asyncio.IncompleteReadError:
            # end of stream before a full message
            return '', ''
        except (BrokenPipeError, ConnectionResetError):
            if self.service.checker.completed():
                self.service.exit(0)
            return '', ''

        return header, body

    async def handle_bgp(self):
        """Run the session: first message through handle_open, the rest through handle_message."""
        state = FSM.OPEN
        while state != FSM.STOP:
            header, body = await self.read_message()
            if not header:
                # the stream is over or broken: nothing more will ever arrive on it
                break

            msg = Message(header, body)

            if state == FSM.OPEN:
                flushed('\nnew session:')
                print_payload('open recv', header, body)
                state = self.handle_open(msg)
                continue

            if state == FSM.MSG:
                print_payload('msg  recv', header, body)
                state = self.handle_message(msg)
                continue

            flushed(f'FSM issue {state}')
            sys.exit(1)
        self.writer.close()

    def handle_open(self, msg):
        """Answer the peer's OPEN and return the next FSM state.

        Sends our OPEN and a KEEPALIVE, plus the optional unknown message and
        default route.  When 'inspect-open-message' is set the received OPEN is
        checked against the rules like any other message.
        """
        options = self.service.options
        option_asn = options['asn']
        option_capa66 = options['send-unknown-capability']
        option_default = options['send-default-route']
        option_open = options['inspect-open-message']
        option_unknown = options['send-unknown-message']

        reply = msg.open(option_asn, option_capa66)
        print_payload('open sent', reply[:19], reply[19:])
        self.writer.write(reply)
        self.writer.write(Message.keepalive())
        if option_unknown:
            # send an invalid/unknown BGP message type to exercise error handling
            unknown = b'\xff' * 16 + pack('!H', 19) + bytes([255])
            self.writer.write(unknown)

        if option_open:
            if not self.service.checker.expected(self.writer, msg):
                self.service.exit(1)
                return FSM.STOP

            if self.service.checker.completed():
                self.service.exit(0)
                return FSM.STOP

            return FSM.MSG

        self.service.checker.perform_actions_if_required(self.writer, msg)

        if option_default:
            flushed('sending default-route\n')
            self.writer.write(Message.default_route())

        return FSM.MSG

    def handle_message(self, msg):
        """Process a message received after the OPEN and return the next FSM state.

        In sink mode a KEEPALIVE is sent back, in echo mode the message itself;
        otherwise a KEEPALIVE is sent and the message is validated by the
        checker, whose verdict decides whether the service exits.
        """
        self.counter += 1
        header, body = msg.header, msg.body

        if self.service.options['sink']:
            print_payload(f'sank     #{self.counter}', header, body)
            self.writer.write(Message.keepalive())
            return FSM.MSG

        if self.service.options['echo']:
            print_payload(f"echo'd  #{self.counter}", header, body)
            self.writer.write(header + body)
            return FSM.MSG

        self.writer.write(Message.keepalive())

        # the checker also records accepted updates so they can be replayed on route refresh
        if not self.service.checker.expected(self.writer, msg):
            self.service.exit(1)
            return FSM.STOP

        if self.service.checker.completed():
            self.service.exit(0)
            return FSM.STOP

        return FSM.MSG


def parse_cmdline():
    """Parse the command line and the optional checks file into an options dict.

    The checks file holds one rule per line.  Blank lines and lines containing
    '#' are skipped, 'option:...' lines update the options and every other
    line is collected, in order, under the 'expect' key.

    Returns:
        The options dictionary.

    Prints the help and exits with status 1 when --help is given or when the
    port / ASN are out of range.
    """
    port = os.environ.get('exabgp.tcp.port', os.environ.get('exabgp_tcp_port', '179'))

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--help', help='this help :-)', action='store_true')
    parser.add_argument(
        '--echo',
        help='accept any BGP messages send it back to the emiter',
        action='store_true',
    )
    parser.add_argument(
        '--sink',
        help='accept any BGP messages and reply with a keepalive',
        action='store_true',
    )
    parser.add_argument(
        '--asn',
        help='ASN to use (otherwise extracted from the OPEN)',
        type=int,
        default=None,
    )
    parser.add_argument('--port', help='port to bind to', type=int, default=port)
    parser.add_argument('--view', help='look at the expected packets', action='store_true')
    parser.add_argument('--ipv6', help='bind using ipv6', action='store_true')
    parser.add_argument(
        'checks',
        help='a list of expected route announcement/withdrawl',
        nargs='?',
        type=open,
    )

    cmdarg = parser.parse_args()

    # Don't allow 4-byte ASN, as peer maybe not supports 4-byte asn
    invalid_port = cmdarg.port <= 0 or cmdarg.port > 65535
    invalid_asn = cmdarg.asn is not None and (cmdarg.asn < 0 or cmdarg.asn > 65535)
    if cmdarg.help or invalid_port or invalid_asn:
        parser.print_help()
        sys.exit(1)

    # fmt: off
    options = {
        'send-unknown-capability': False,  # add an unknown capability to the open message
        'send-default-route': False,       # send a default route to the peer
        'inspect-open-message': False,     # check for received OPEN message
        'send-unknown-message': False,     # send an unknown message type to the peer
        'asn': None,                       # Don't modify the local AS per default.
        'sink': False,                     # just accept whatever is sent
        'echo': False,                     # just send back whatever is sent
        'ipv6': False,                     # bind to ipv4 by default
        'expect': []
    }

    options.update({
        'sink': cmdarg.sink,
        'echo': cmdarg.echo,
        'port': cmdarg.port,
        'asn': cmdarg.asn,
        'ipv6': cmdarg.ipv6,
        'view': cmdarg.view,
    })

    # option lines which simply switch a flag on
    flags = {
        'option:bind:ipv6': 'ipv6',
        'option:open:send-unknown-capability': 'send-unknown-capability',
        'option:open:inspect-open-message': 'inspect-open-message',
        'option:update:send-default-route': 'send-default-route',
        'option:open:send-unknown-message': 'send-unknown-message',
    }
    # fmt: on

    if not cmdarg.checks:
        return options

    # argparse opened the file for us: make sure it gets closed
    with cmdarg.checks as checks:
        content = [line.strip() for line in checks if line.strip() and '#' not in line]

    expect = []
    for message in content:
        if message in flags:
            options[flags[message]] = True
            continue
        if message.startswith('option:asn:'):
            options['asn'] = int(message.split(':')[-1])
            continue
        if message.startswith('option:tcp_connections:'):
            options['tcp_connections'] = int(message.split(':')[-1])
            continue
        if message.startswith('option:SIGUSR1:'):
            delay = int(message.split(':')[-1])
            myself = os.getpid()

            # NOTE(review): sending SIGUSR1 to our own pid is inferred from the
            # helper's name and arguments - confirm this is the intended target.
            def suicide(delay, myself):
                time.sleep(delay)
                os.kill(myself, signal.SIGUSR1)

            # daemon thread: a pending timer must not keep the process alive
            threading.Thread(target=suicide, args=(delay, myself), daemon=True).start()
            # an option line is not an expected message
            continue

        expect.append(message)

    options['expect'] = expect
    return options


async def main(options, checker, queue):
    """Bind the listening socket and serve BGP sessions until the loop is stopped.

    Args:
        options: the dictionary built by parse_cmdline().
        checker: the Checker validating received messages.
        queue: queue on which the exit code is reported (see BGPService.exit).

    Exits the program with status 1 when there is nothing to test against or
    when the socket can not be bound.
    """
    if options['sink']:
        flushed('\nsink mode - send us whatever, we can take it ! :p\n')
    elif options['echo']:
        flushed('\necho mode - send us whatever, we can parrot it ! :p\n')
    elif not options['expect']:
        flushed('no test data available to test against')
        sys.exit(1)

    # an IPv6 socket is always used; the IPv4 loopback is reached via its IPv4-mapped address
    host = '::1' if options['ipv6'] else '::FFFF:127.0.0.1'
    port = options['port']

    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        sock.bind((host, port))
    except OSError as exc:
        # covers PermissionError as well as e.g. an address already in use
        sock.close()
        flushed(f'could not bind to {host}:{port}')
        flushed(exc)
        sys.exit(1)

    loop = asyncio.get_running_loop()
    service = BGPService(loop, queue, options, checker)

    server = await asyncio.start_server(
        lambda reader, writer: BGPProtocol(service, reader, writer).handle_bgp(),
        sock=sock,
    )
    # perhaps set backlog to 1 ..

    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    # Build the options and the checker, optionally show the rules, then serve
    # until BGPService.exit() reports an exit code on the queue.
    options = parse_cmdline()
    checker = Checker(options['expect'])

    if options['view']:
        flushed('')
        checker.rules()

    queue = queue.Queue()
    try:
        asyncio.run(main(options, checker, queue))
    except RuntimeError:
        # NOTE(review): presumably raised by asyncio.run() when BGPService.exit() stops the loop
        pass

    flushed('\n')

    # no code on the queue means the server ended without reaching a verdict
    if queue.empty():
        flushed('failed with no code')
        sys.exit(1)

    code = queue.get()
    if code == 0:
        flushed('successful')
        sys.exit(0)

    flushed(f'failed with code {code}')
    sys.exit(code)
