import asyncio
import threading
import sys
from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
from fastapi.security import OAuth2PasswordBearer
from google.cloud.speech import RecognitionConfig, StreamingRecognitionConfig
import janus
import queue
from google.cloud import speech
import os
from fastapi import FastAPI
import logging

# Point the Google client libraries at the bundled service-account key file
# unless credentials were already configured in the environment, so a
# deployment-provided GOOGLE_APPLICATION_CREDENTIALS takes precedence.
# The path is relative to the process working directory.
os.environ.setdefault('GOOGLE_APPLICATION_CREDENTIALS', 'medical_memo.json')
# NOTE(review): not referenced by any endpoint visible in this file — confirm
# whether it is still needed.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

app = FastAPI()


# Root logging goes to stderr at INFO; the module logger itself allows DEBUG.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Audio format the recognizer is configured for: mono LINEAR16 PCM at RATE Hz.
RATE = 16000
# 100 ms worth of samples at RATE; not referenced in the visible code.
CHUNK = int(RATE / 10)
config = RecognitionConfig(
    encoding=RecognitionConfig.AudioEncoding.LINEAR16,
    # Derived from RATE so the declared sample rate cannot drift from it.
    sample_rate_hertz=RATE,
    audio_channel_count=1,
    use_enhanced=True,
    language_code="ja-JP",
    model="command_and_search",
    enable_spoken_punctuation=True
)
# interim_results=True: partial hypotheses are streamed before final results.
streaming_config = StreamingRecognitionConfig(config=config, interim_results=True)

@app.get("/health")
async def health_check():
    """Liveness probe: always report the service as up."""
    status = "ok"
    return {"status": status}

class SpeechClientBridge:
    """Feed buffered audio chunks into a Google streaming-recognition call.

    Audio pushed with ``add_request`` is collected in a thread-safe queue,
    batched by ``generator`` and sent through
    ``SpeechClient.streaming_recognize``. Every response is passed to
    ``on_response(response, q, thread)``, where ``q`` and ``thread`` are the
    values given to ``start``.
    """

    def __init__(self, streaming_config, on_response):
        self._on_response = on_response
        # Pending audio chunks (bytes); ``None`` is the stop sentinel.
        self._queue = queue.Queue()
        self._ended = False
        self.streaming_config = streaming_config

    def start(self, q, thread):
        """Run one recognition session; blocks until the response stream ends.

        Intended to be the target of a worker thread. Errors are logged here
        because this is the thread's top-level boundary.

        Args:
            q: passed through to ``on_response`` (where results are put).
            thread: opaque value passed through to ``on_response``.
        """
        self._ended = False
        self.q = q
        self.thread = thread
        try:
            client = speech.SpeechClient()
            requests = (
                speech.StreamingRecognizeRequest(audio_content=content)
                for content in self.generator()
            )
            responses = client.streaming_recognize(self.streaming_config, requests)
            self.process_responses_loop(responses)
        except Exception:
            logging.exception("Streaming recognition failed")
        finally:
            # The session is over; stop buffering audio that will never be sent.
            self._ended = True

    def terminate(self):
        """Stop the session: end the audio generator and the response loop."""
        self._ended = True
        # Wake the generator if it is blocked in ``_queue.get()``; otherwise the
        # request stream never ends and the worker thread hangs forever.
        self._queue.put(None)

    def add_request(self, buffer):
        """Queue an audio chunk for sending; ignored once the session ended."""
        if self._ended:
            return
        self._queue.put(bytes(buffer), block=False)

    def process_responses_loop(self, responses):
        """Dispatch each response to the callback until the stream ends or stop."""
        for response in responses:
            self._on_response(response, self.q, self.thread)
            if self._ended:
                break

    def generator(self):
        """Yield audio to send, joining all chunks queued since the last yield.

        Blocks for the first chunk, then drains whatever else is already queued
        without blocking. Returns as soon as the ``None`` sentinel is seen
        (discarding any chunks gathered in that round) or the session ended.
        """
        while not self._ended:
            chunk = self._queue.get()
            if chunk is None:
                return
            data = [chunk]
            while True:
                try:
                    chunk = self._queue.get(block=False)
                    if chunk is None:
                        return
                    data.append(chunk)
                except queue.Empty:
                    break
            yield b"".join(data)


def on_transcription_response(response, q, websocket):
    """Handle one streaming-recognition response.

    Takes the top alternative of the first result, echoes it to stdout
    (interim results rewrite the current console line, final results are
    printed on their own line prefixed with ``==>``), logs it, and puts
    ``{"is_final": bool, "transcript": str}`` on ``q``. Responses without
    results or alternatives are ignored.

    Args:
        response: recognition response; only ``results[0]`` is read.
        q: synchronous queue that receives the transcript dicts.
        websocket: unused here; accepted to match the callback signature.
    """
    if not response.results:
        return
    result = response.results[0]
    if not result.alternatives:
        return

    transcript = result.alternatives[0].transcript
    # Width of the interim line currently on the console. Each call handles a
    # single response, so the value must persist across calls; it is kept on
    # the function object (shared by all connections — console cosmetics only).
    num_chars_printed = getattr(on_transcription_response, "_num_chars_printed", 0)
    # Spaces that blank out the tail of a longer previous interim line.
    overwrite_chars = " " * (num_chars_printed - len(transcript))

    if not result.is_final:
        logging.info(f"Partial transcription: {transcript}")
        sys.stdout.write(transcript + overwrite_chars + "\r")
        sys.stdout.flush()
        on_transcription_response._num_chars_printed = len(transcript)
        q.put({"is_final": False, "transcript": transcript})
    else:
        print('==>' + transcript + overwrite_chars)
        logging.info(f"Final transcription: {transcript}")
        # The final result ends the console line; the next interim starts fresh.
        on_transcription_response._num_chars_printed = 0
        q.put({"is_final": True, "transcript": transcript})


async def send_transcription(websocket: WebSocket, async_q: janus._AsyncQueueProxy):
    """Forward transcription results from ``async_q`` to the client.

    Each item must be a mapping with ``is_final`` and ``transcript`` keys; it
    is sent as a JSON object with exactly those two fields. Runs until the task
    is cancelled or sending fails because the connection is gone. Both cases
    end the task quietly, so it never holds an unretrieved exception.
    """
    try:
        while True:
            send_data = await async_q.get()
            await websocket.send_json({
                "is_final": send_data['is_final'],
                "transcript": send_data['transcript']
            })
    except asyncio.CancelledError:
        pass
    except (WebSocketDisconnect, RuntimeError) as exc:
        # Starlette raises these when sending on a closed/disconnected socket.
        logging.info(f"Stopped sending transcriptions: {exc!r}")


@app.websocket("/ws/websocket")
async def websocket_endpoint(websocket: WebSocket):
    """Stream client audio to speech recognition and push transcripts back.

    Binary frames are handed to a SpeechClientBridge running on a worker
    thread. Text frames are echoed back. Results travel from the worker
    thread to ``send_transcription`` through a janus queue. The bridge and the
    sender task are always shut down when the loop exits, including on
    cancellation.
    """
    await websocket.accept()
    print("opened")
    q = janus.Queue()
    bridge = SpeechClientBridge(streaming_config, on_transcription_response)
    # daemon=True: a stuck recognition stream must not block process exit.
    bridge_thread = threading.Thread(
        target=bridge.start, args=(q.sync_q, websocket), daemon=True
    )
    bridge_thread.start()
    send_task = asyncio.create_task(send_transcription(websocket, q.async_q))
    try:
        while True:
            data = await websocket.receive()
            # receive() reports a client disconnect as a message rather than
            # raising; calling it again afterwards raises RuntimeError.
            if data["type"] == "websocket.disconnect":
                print("closed")
                break
            # ASGI allows both keys to be present with one set to None.
            message = data.get("text")
            if message is not None:
                logging.info(f"Received text: {message}")
                await websocket.send_text(message)
                continue
            message = data.get("bytes")
            if message is not None:
                logging.info(f"Received audio bytes: {len(message)} bytes")
                if message:
                    bridge.add_request(message)
    except WebSocketDisconnect:
        print("closed")
    except Exception as e:
        print(f"Exception: {e}")
    finally:
        bridge.terminate()
        send_task.cancel()

