from __future__ import annotations

import inspect
from types import TracebackType
from typing import Any, List, Generic, Iterable, Awaitable, cast
from typing_extensions import Self, Callable, Iterator, AsyncIterator

from ._types import ParsedResponseSnapshot
from ._events import (
    ResponseStreamEvent,
    ResponseTextDoneEvent,
    ResponseCompletedEvent,
    ResponseTextDeltaEvent,
    ResponseFunctionCallArgumentsDeltaEvent,
)
from ...._types import NOT_GIVEN, NotGiven
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._models import build, construct_type_unchecked
from ...._streaming import Stream, AsyncStream
from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
from ..._parsing._responses import TextFormatT, parse_text, parse_response
from ....types.responses.tool_param import ToolParam
from ....types.responses.parsed_response import (
    ParsedContent,
    ParsedResponseOutputMessage,
    ParsedResponseFunctionToolCall,
)


class ResponseStream(Generic[TextFormatT]):
    def __init__(
        self,
        *,
        raw_stream: Stream[RawResponseStreamEvent],
        text_format: type[TextFormatT] | NotGiven,
        input_tools: Iterable[ToolParam] | NotGiven,
        starting_after: int | None,
    ) -> None:
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        self._iterator = self.__stream__()
        self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
        self._starting_after = starting_after

    def __next__(self) -> ResponseStreamEvent[TextFormatT]:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
        for item in self._iterator:
            yield item

    def __enter__(self) -> Self:
        return self

    def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
        for sse_event in self._raw_stream:
            events_to_fire = self._state.handle_event(sse_event)
            for event in events_to_fire:
                if self._starting_after is None or event.sequence_number > self._starting_after:
                    yield event

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self._response.close()

    def get_final_response(self) -> ParsedResponse[TextFormatT]:
        """Waits until the stream has been read to completion and returns
        the accumulated `ParsedResponse` object.
        """
        self.until_done()
        response = self._state._completed_response
        if not response:
            raise RuntimeError("Didn't receive a `response.completed` event.")

        return response

    def until_done(self) -> Self:
        """Blocks until the stream has been consumed."""
        consume_sync_iterator(self)
        return self


class ResponseStreamManager(Generic[TextFormatT]):
    def __init__(
        self,
        api_request: Callable[[], Stream[RawResponseStreamEvent]],
        *,
        text_format: type[TextFormatT] | NotGiven,
        input_tools: Iterable[ToolParam] | NotGiven,
        starting_after: int | None,
    ) -> None:
        self.__stream: ResponseStream[TextFormatT] | None = None
        self.__api_request = api_request
        self.__text_format = text_format
        self.__input_tools = input_tools
        self.__starting_after = starting_after

    def __enter__(self) -> ResponseStream[TextFormatT]:
        raw_stream = self.__api_request()

        self.__stream = ResponseStream(
            raw_stream=raw_stream,
            text_format=self.__text_format,
            input_tools=self.__input_tools,
            starting_after=self.__starting_after,
        )

        return self.__stream

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.__stream is not None:
            self.__stream.close()


class AsyncResponseStream(Generic[TextFormatT]):
    def __init__(
        self,
        *,
        raw_stream: AsyncStream[RawResponseStreamEvent],
        text_format: type[TextFormatT] | NotGiven,
        input_tools: Iterable[ToolParam] | NotGiven,
        starting_after: int | None,
    ) -> None:
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        self._iterator = self.__stream__()
        self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
        self._starting_after = starting_after

    async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
        async for item in self._iterator:
            yield item

    async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
        async for sse_event in self._raw_stream:
            events_to_fire = self._state.handle_event(sse_event)
            for event in events_to_fire:
                if self._starting_after is None or event.sequence_number > self._starting_after:
                    yield event

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self._response.aclose()

    async def get_final_response(self) -> ParsedResponse[TextFormatT]:
        """Waits until the stream has been read to completion and returns
        the accumulated `ParsedResponse` object.
        """
        await self.until_done()
        response = self._state._completed_response
        if not response:
            raise RuntimeError("Didn't receive a `response.completed` event.")

        return response

    async def until_done(self) -> Self:
        """Blocks until the stream has been consumed."""
        await consume_async_iterator(self)
        return self


class AsyncResponseStreamManager(Generic[TextFormatT]):
    def __init__(
        self,
        api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
        *,
        text_format: type[TextFormatT] | NotGiven,
        input_tools: Iterable[ToolParam] | NotGiven,
        starting_after: int | None,
    ) -> None:
        self.__stream: AsyncResponseStream[TextFormatT] | None = None
        self.__api_request = api_request
        self.__text_format = text_format
        self.__input_tools = input_tools
        self.__starting_after = starting_after

    async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
        raw_stream = await self.__api_request

        self.__stream = AsyncResponseStream(
            raw_stream=raw_stream,
            text_format=self.__text_format,
            input_tools=self.__input_tools,
            starting_after=self.__starting_after,
        )

        return self.__stream

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.__stream is not None:
            await self.__stream.close()


class ResponseStreamState(Generic[TextFormatT]):
    def __init__(
        self,
        *,
        input_tools: Iterable[ToolParam] | NotGiven,
        text_format: type[TextFormatT] | NotGiven,
    ) -> None:
        self.__current_snapshot: ParsedResponseSnapshot | None = None
        self._completed_response: ParsedResponse[TextFormatT] | None = None
        self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
        self._text_format = text_format
        self._rich_text_format: type | NotGiven = text_format if inspect.isclass(text_format) else NOT_GIVEN

    def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
        self.__current_snapshot = snapshot = self.accumulate_event(event)

        events: List[ResponseStreamEvent[TextFormatT]] = []

        if event.type == "response.output_text.delta":
            output = snapshot.output[event.output_index]
            assert output.type == "message"

            content = output.content[event.content_index]
            assert content.type == "output_text"

            events.append(
                build(
                    ResponseTextDeltaEvent,
                    content_index=event.content_index,
                    delta=event.delta,
                    item_id=event.item_id,
                    output_index=event.output_index,
                    sequence_number=event.sequence_number,
                    type="response.output_text.delta",
                    snapshot=content.text,
                )
            )
        elif event.type == "response.output_text.done":
            output = snapshot.output[event.output_index]
            assert output.type == "message"

            content = output.content[event.content_index]
            assert content.type == "output_text"

            events.append(
                build(
                    ResponseTextDoneEvent[TextFormatT],
                    content_index=event.content_index,
                    item_id=event.item_id,
                    output_index=event.output_index,
                    sequence_number=event.sequence_number,
                    type="response.output_text.done",
                    text=event.text,
                    parsed=parse_text(event.text, text_format=self._text_format),
                )
            )
        elif event.type == "response.function_call_arguments.delta":
            output = snapshot.output[event.output_index]
            assert output.type == "function_call"

            events.append(
                build(
                    ResponseFunctionCallArgumentsDeltaEvent,
                    delta=event.delta,
                    item_id=event.item_id,
                    output_index=event.output_index,
                    sequence_number=event.sequence_number,
                    type="response.function_call_arguments.delta",
                    snapshot=output.arguments,
                )
            )

        elif event.type == "response.completed":
            response = self._completed_response
            assert response is not None

            events.append(
                build(
                    ResponseCompletedEvent,
                    sequence_number=event.sequence_number,
                    type="response.completed",
                    response=response,
                )
            )
        else:
            events.append(event)

        return events

    def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
        snapshot = self.__current_snapshot
        if snapshot is None:
            return self._create_initial_response(event)

        if event.type == "response.output_item.added":
            if event.item.type == "function_call":
                snapshot.output.append(
                    construct_type_unchecked(
                        type_=cast(Any, ParsedResponseFunctionToolCall), value=event.item.to_dict()
                    )
                )
            elif event.item.type == "message":
                snapshot.output.append(
                    construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=event.item.to_dict())
                )
            else:
                snapshot.output.append(event.item)
        elif event.type == "response.content_part.added":
            output = snapshot.output[event.output_index]
            if output.type == "message":
                output.content.append(
                    construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
                )
        elif event.type == "response.output_text.delta":
            output = snapshot.output[event.output_index]
            if output.type == "message":
                content = output.content[event.content_index]
                assert content.type == "output_text"
                content.text += event.delta
        elif event.type == "response.function_call_arguments.delta":
            output = snapshot.output[event.output_index]
            if output.type == "function_call":
                output.arguments += event.delta
        elif event.type == "response.completed":
            self._completed_response = parse_response(
                text_format=self._text_format,
                response=event.response,
                input_tools=self._input_tools,
            )

        return snapshot

    def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
        if event.type != "response.created":
            raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")

        return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())
