#!/usr/bin/python3
from argparse import ArgumentParser, FileType
from enum import Enum, auto
from os import fstat, utime
from re import compile
from shutil import copyfileobj
from sys import stdout
from tempfile import TemporaryFile


class State(Enum):
    """Phase of the single-pass scan over the source file's lines."""
    SEARCH = auto()  # copying lines until `first` matches
    SKIP = auto()  # diverting lines to the skipped buffer until `last` matches
    DUMP = auto()  # block removed; copying the remaining lines unchanged


def regex_argument(text):
    """Compile *text* as a regular expression for use as an argparse ``type``.

    ``re.compile`` raises ``re.error`` on a malformed pattern. argparse does
    not translate that exception, so a typo would surface as a traceback.
    Re-raising it as ``ArgumentTypeError`` makes argparse print a normal
    usage error instead.

    Returns the compiled pattern object.
    """
    from argparse import ArgumentTypeError
    from re import error as RegexError

    try:
        return compile(text)
    except RegexError as exc:
        raise ArgumentTypeError(
            f"invalid regular expression {text!r}: {exc}"
        ) from exc


parser = ArgumentParser(
    description="Patch out a c4project-style dependency download"
)
# FileType("r+") opens the file for both reading and in-place rewriting.
parser.add_argument(
    "sourcefile",
    metavar="SOURCEFILE",
    type=FileType("r+"),
    help="Path to the source file on which to operate",
)
# Distinct metavars so the usage line shows which pattern is which.
parser.add_argument(
    "first",
    metavar="FIRST",
    type=regex_argument,
    help=(
        "A Python regular expression matching the first line of the block "
        "to be patched out"
    ),
)
parser.add_argument(
    "last",
    metavar="LAST",
    type=regex_argument,
    help=(
        "A Python regular expression matching the last line of the block "
        "to be patched out"
    ),
)
# Parse the command line; this also opens SOURCEFILE (mode "r+").
args = parser.parse_args()

with (
    args.sourcefile as before,
    TemporaryFile(mode="w+") as after,
    TemporaryFile(mode="w+") as skipped,
):
    # Remember the original access/modification times so they can be
    # restored after the file has been rewritten (see utime() below).
    st = fstat(before.fileno())
    times = st.st_atime_ns, st.st_mtime_ns

    # Single-pass state machine:
    #   SEARCH - copy lines to `after` until `first` matches;
    #   SKIP   - divert lines, including the one matching `first` and the
    #            one matching `last`, to `skipped`. `last` is also tested
    #            against the line that matched `first`;
    #   DUMP   - copy the remaining lines to `after` unchanged.
    # Only the first matching block is removed.
    state = State.SEARCH
    n = 0  # number of lines diverted to `skipped`
    for line in before:
        if state == State.SEARCH:
            if args.first.search(line) is None:
                after.write(line)
                continue
            # Deliberate fall-through: the matching line starts the block.
            state = State.SKIP
        if state == State.SKIP:
            skipped.write(line)
            n += 1
            if args.last.search(line) is not None:
                state = State.DUMP
            continue
        if state == State.DUMP:
            after.write(line)

    # Abort without modifying the file if the block was not found or was
    # never terminated.
    if state == State.SEARCH:
        raise SystemExit(
            f"No match for {args.first.pattern!r} in {before.name}"
        )
    elif state == State.SKIP:
        raise SystemExit(
            f"No match for {args.last.pattern!r} after "
            f"{args.first.pattern!r} in {before.name}"
        )

    # Rewrite the source file in place with the retained lines.
    skipped.seek(0)
    after.seek(0)
    before.seek(0)
    before.truncate()
    copyfileobj(after, before)

    # Report what was removed. The plural suffix is empty for one line.
    print(f"Patched out {n} line{'' if n == 1 else 's'} in {before.name}:")
    bar = "=" * 78
    print(bar)
    copyfileobj(skipped, stdout)
    print(bar)

# Restore the original timestamps. This happens after the `with` block has
# closed the file, presumably so the final flush on close cannot bump mtime
# again.
utime(args.sourcefile.name, ns=times)
