from __future__ import annotations

import ast
import contextlib
import inspect
import sys
import warnings
from code import InteractiveConsole
from types import CodeType, FrameType, FunctionType
from typing import Callable

import outcome

import trio
import trio.lowlevel
from trio._util import final


class SuppressDecorator(contextlib.ContextDecorator, contextlib.suppress):
    """``contextlib.suppress`` that can also be applied as a function decorator.

    ``contextlib.suppress`` on its own is only a context manager; mixing in
    ``ContextDecorator`` lets ``@SuppressDecorator(SomeError)`` wrap a whole
    function body so the listed exceptions are swallowed on every call.
    """


@SuppressDecorator(KeyboardInterrupt)
@trio.lowlevel.disable_ki_protection
def terminal_newline() -> None:  # TODO: test this line
    """Push a newline into the terminal's input queue, as if the user typed it.

    Uses the ``TIOCSTI`` ioctl on ``sys.stdin``. If the ioctl is refused
    (``OSError``), a message asking the user to press enter is printed instead.
    Runs with Trio's KI protection disabled, and any ``KeyboardInterrupt``
    raised while it runs is suppressed by the outer decorator.
    """
    # Imported lazily: these modules are not available on every platform.
    import fcntl
    import termios

    newline = b"\n"
    try:
        fcntl.ioctl(sys.stdin, termios.TIOCSTI, newline)  # type: ignore[attr-defined, unused-ignore]
    except OSError as exc:
        print(f"\nPress enter! Newline injection failed: {exc}", end="", flush=True)


@final
class TrioInteractiveConsole(InteractiveConsole):
    """An ``InteractiveConsole`` that executes user input inside a running Trio loop.

    The blocking console methods are expected to run in a worker thread
    (``run_repl`` starts ``interact`` through ``trio.to_thread.run_sync``).
    Whenever a method needs to run user code or touch signal handlers it hops
    back into Trio through ``trio.from_thread``.
    """

    def __init__(self, repl_locals: dict[str, object] | None = None) -> None:
        """Create the console.

        Args:
            repl_locals: Namespace in which REPL input is executed; forwarded
                to ``InteractiveConsole`` as ``locals``.
        """
        super().__init__(locals=repl_locals)
        # NOTE(review): never read in the visible code — presumably kept for
        # external users; confirm before removing.
        self.token: trio.lowlevel.TrioToken | None = None
        # Let the compiler accept ``await`` at the top level of REPL input.
        self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
        # Set by the SIGINT handler installed in raw_input(); cleared by write().
        self.interrupted = False

    def runcode(self, code: CodeType) -> None:
        """Execute one compiled chunk of REPL input via the Trio thread.

        The code object is wrapped in a function whose globals are
        ``self.locals``. If that function is a coroutine function (the input
        used top-level ``await``), it is awaited with ``trio.from_thread.run``;
        otherwise it is called with ``trio.from_thread.run_sync``. Either way
        the result is captured as an ``outcome`` so that exceptions come back
        as values instead of propagating through the thread hop.

        Raises:
            SystemExit: Re-raised when the executed code raised it directly,
                so that the REPL exits.
        """
        func = FunctionType(code, self.locals)
        if inspect.iscoroutinefunction(func):
            result = trio.from_thread.run(outcome.acapture, func)
        else:
            result = trio.from_thread.run_sync(outcome.capture, func)
        if isinstance(result, outcome.Error):
            # If it is SystemExit, quit the repl. Otherwise, print the traceback.
            # If there is a SystemExit inside a BaseExceptionGroup, it probably isn't
            # the user trying to quit the repl, but rather an error in the code. So, we
            # don't try to inspect groups for SystemExit. Instead, we just print and
            # return to the REPL.
            if isinstance(result.error, SystemExit):
                raise result.error
            else:
                # Inline our own version of self.showtraceback that can use
                # outcome.Error.error directly to print clean tracebacks.
                # This also means overriding self.showtraceback does nothing.
                sys.last_type, sys.last_value = type(result.error), result.error
                sys.last_traceback = result.error.__traceback__
                # see https://docs.python.org/3/library/sys.html#sys.last_exc
                if sys.version_info >= (3, 12):
                    sys.last_exc = result.error

                # We always use sys.excepthook, unlike other implementations.
                # This means that overriding self.write also does nothing to tbs.
                sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)
        # clear any residual KI
        trio.from_thread.run(trio.lowlevel.checkpoint_if_cancelled)
        # trio.from_thread.check_cancelled() has too long of a memory

    # raw_input() is defined per platform: the non-Windows variant relies on
    # terminal_newline(), which imports the fcntl and termios modules.
    if sys.platform == "win32":  # TODO: test this line

        def raw_input(self, prompt: str = "") -> str:
            """Read one line with ``input``.

            On ``EOFError``, ``checkpoint_if_cancelled`` is first run in Trio
            so a pending KeyboardInterrupt can surface; otherwise the
            ``EOFError`` is re-raised unchanged.
            """
            try:
                return input(prompt)
            except EOFError:
                # check if trio has a pending KI
                trio.from_thread.run(trio.lowlevel.checkpoint_if_cancelled)
                raise

    else:

        def raw_input(self, prompt: str = "") -> str:
            """Read one line with ``input`` while a temporary SIGINT handler is installed.

            The handler records the interrupt in ``self.interrupted`` and
            schedules ``terminal_newline`` on the Trio loop — presumably so the
            blocking ``input`` call returns. The previous SIGINT handler is
            always restored before this method exits.

            Raises:
                KeyboardInterrupt: If SIGINT arrived while waiting for input;
                    this replaces whatever ``input`` returned or raised.
            """
            from signal import SIGINT, signal

            # write() must have consumed any earlier interrupt before we read again.
            assert not self.interrupted

            def install_handler() -> (
                Callable[[int, FrameType | None], None] | int | None
            ):
                """Install the SIGINT handler and return the one it replaced."""

                def handler(
                    sig: int, frame: FrameType | None
                ) -> None:  # TODO: test this line
                    self.interrupted = True
                    # idempotent=True: repeated Ctrl-C presses queue a single call.
                    token.run_sync_soon(terminal_newline, idempotent=True)

                # Bound before signal() installs the handler, so the closure
                # above always finds it.
                token = trio.lowlevel.current_trio_token()

                return signal(SIGINT, handler)

            # signal() may only be called from the main thread, so the handler is
            # installed from the Trio thread — assumes trio.run() is on the main
            # thread; TODO confirm.
            prev_handler = trio.from_thread.run_sync(install_handler)
            try:
                return input(prompt)
            finally:
                trio.from_thread.run_sync(signal, SIGINT, prev_handler)
                if self.interrupted:  # TODO: test this line
                    # Raising here overrides the pending return value/exception.
                    raise KeyboardInterrupt

        def write(self, output: str) -> None:
            """Write console output to ``sys.stderr``.

            After an interrupt recorded by ``raw_input``, the only output
            expected is the console's KeyboardInterrupt notice. Its leading
            newline is dropped (presumably because ``terminal_newline`` has
            already moved the terminal to a fresh line) and the
            ``interrupted`` flag is cleared.
            """
            if self.interrupted:  # TODO: test this line
                assert output == "\nKeyboardInterrupt\n"
                sys.stderr.write(output[1:])
                self.interrupted = False
            else:
                sys.stderr.write(output)


async def run_repl(console: TrioInteractiveConsole) -> None:
    """Run ``console.interact`` in a worker thread until the session ends.

    The banner mimics the standard interpreter banner and finishes with a fake
    ``import trio`` prompt line. When the session ends, for whatever reason, a
    filter is installed that ignores "coroutine ... was never awaited"
    ``RuntimeWarning`` messages — presumably to silence warnings about
    coroutine objects left behind in the REPL namespace.

    Args:
        console: The console whose ``interact`` loop is run.
    """
    prompt = getattr(sys, "ps1", ">>> ")
    banner_lines = [
        f"trio REPL {sys.version} on {sys.platform}",
        'Use "await" directly instead of "trio.run()".',
        'Type "help", "copyright", "credits" or "license" for more information.',
        f"{prompt}import trio",
    ]
    banner = "\n".join(banner_lines)

    try:
        await trio.to_thread.run_sync(console.interact, banner)
    finally:
        warnings.filterwarnings(
            "ignore",
            message=r"^coroutine .* was never awaited$",
            category=RuntimeWarning,
        )


def main(original_locals: dict[str, object]) -> None:
    """Start the Trio REPL and block until it exits.

    Args:
        original_locals: Namespace of the launching module. It must contain
            every module-level dunder copied below; a missing one raises
            ``KeyError``.
    """
    # Loading readline is optional; when it is available, input() gains line
    # editing and history.
    with contextlib.suppress(ImportError):
        import readline  # noqa: F401

    inherited_names = {
        "__name__",
        "__package__",
        "__loader__",
        "__spec__",
        "__builtins__",
        "__file__",
    }
    # The REPL namespace starts with ``trio`` plus the launcher's module dunders.
    repl_locals: dict[str, object] = {"trio": trio}
    repl_locals.update({name: original_locals[name] for name in inherited_names})

    trio.run(run_repl, TrioInteractiveConsole(repl_locals))
