import asyncio
import threading
import sys
from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
from fastapi.security import OAuth2PasswordBearer
from google.cloud.speech import RecognitionConfig, StreamingRecognitionConfig
import janus
import queue
from google.cloud import speech
import os
from fastapi import FastAPI
import logging
import json

# Use the credentials path from the deployment environment when one is set;
# fall back to the bundled service-account file otherwise.
os.environ.setdefault('GOOGLE_APPLICATION_CREDENTIALS', 'medical_memo.json')
# NOTE(review): not referenced in the code visible here — confirm it is still needed.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

app = FastAPI()


logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Sample rate (Hz) of the LINEAR16 mono audio the recognizer is told to expect.
RATE = 48000
# RATE / 10 samples, i.e. 100 ms of audio at RATE.
CHUNK = int(RATE / 10)
config = RecognitionConfig(
    encoding=RecognitionConfig.AudioEncoding.LINEAR16,
    sample_rate_hertz=RATE,
    audio_channel_count=1,
    use_enhanced=True,
    # NOTE(review): the WebSocket route is /wslang/zh/ but recognition is
    # configured for Japanese; switch to "zh-CN" if Chinese is intended.
    language_code="ja-JP",
    model="command_and_search",
    enable_spoken_punctuation=True
)
# interim_results=True makes the stream emit partial (non-final) transcripts too.
streaming_config = StreamingRecognitionConfig(config=config, interim_results=True)

@app.get("/health")
async def health_check():
    """Liveness endpoint: always answers with a fixed OK payload."""
    payload = {"status": "ok"}
    return payload

class SpeechClientBridge:
    """Bridge pushed audio chunks into a blocking Google streaming-recognize call.

    Audio is enqueued with :meth:`add_request` and consumed by
    :meth:`generator`, which :meth:`start` iterates while it runs on its own
    (blocking) thread. Every recognition response is handed to
    ``on_response(response, q, thread)``.
    """

    def __init__(self, streaming_config, on_response):
        self._on_response = on_response
        # Pending audio; ``None`` is the end-of-stream sentinel.
        self._queue = queue.Queue()
        self._ended = False
        self.streaming_config = streaming_config

    def start(self, q, thread):
        """Run one recognition stream; blocks until it finishes.

        Args:
            q: Queue passed through to ``on_response`` for results.
            thread: Opaque value passed through to ``on_response`` (the
                WebSocket endpoint in this file passes the socket here).
        """
        self._ended = False
        self.q = q
        self.thread = thread
        # Closing the client releases its transport. That also tears down a
        # response stream left unread after an early break on termination.
        with speech.SpeechClient() as client:
            requests = (
                speech.StreamingRecognizeRequest(audio_content=content)
                for content in self.generator()
            )
            responses = client.streaming_recognize(self.streaming_config, requests)
            self.process_responses_loop(responses)
        # Whatever is still queued was never consumed downstream; dump it.
        while not self.q.empty():
            print(self.q.get())

    def terminate(self):
        """Stop the stream and wake a generator blocked waiting for audio."""
        self._ended = True
        # generator() may be parked in self._queue.get(). Without a sentinel
        # the worker thread would block there forever.
        self._queue.put(None)

    def add_request(self, buffer):
        """Enqueue one chunk of audio (any bytes-like object)."""
        self._queue.put(bytes(buffer), block=False)

    def process_responses_loop(self, responses):
        """Dispatch each response to the callback; stop after terminate()."""
        for response in responses:
            self._on_response(response, self.q, self.thread)
            if self._ended:
                break

    def generator(self):
        """Yield the buffered audio, joining all chunks available at once.

        Blocks until at least one chunk is available, then drains whatever
        else is already queued, so each yielded request carries as much
        audio as possible. Returns on the ``None`` sentinel or once
        terminate() has been called.
        """
        while not self._ended:
            chunk = self._queue.get()
            if chunk is None:
                return
            data = [chunk]
            while True:
                try:
                    chunk = self._queue.get(block=False)
                    if chunk is None:
                        return
                    data.append(chunk)
                except queue.Empty:
                    break
            yield b"".join(data)


def on_transcription_response(response, q, websocket):
    """Publish the first alternative of a streaming-recognize response to ``q``.

    Responses with no results, or whose first result has no alternatives,
    are ignored. Interim transcripts are echoed to stdout ending in ``\\r``.
    Final transcripts are printed prefixed with ``==>``. In both cases
    ``{"is_final": bool, "transcript": str}`` is put on ``q``.

    Args:
        response: Recognition response; reads ``results[0].is_final`` and
            ``results[0].alternatives[0].transcript``.
        q: Synchronous queue that receives the result dict.
        websocket: Unused; accepted because SpeechClientBridge passes it.
    """
    if not response.results:
        return
    result = response.results[0]
    if not result.alternatives:
        return

    transcript = result.alternatives[0].transcript
    is_final = bool(result.is_final)

    if is_final:
        print('==>' + transcript)
        logging.info("Final transcription: %s", transcript)
    else:
        logging.info("Partial transcription: %s", transcript)
        sys.stdout.write(transcript + "\r")
        sys.stdout.flush()
    q.put({"is_final": is_final, "transcript": transcript})


async def send_transcription(websocket: WebSocket, async_q: janus._AsyncQueueProxy):
    """Forward transcription results from ``async_q`` to the client until stopped.

    Each item must be a mapping with ``"is_final"`` and ``"transcript"`` keys.
    It is sent as a JSON object with exactly those two keys.

    Cancellation is logged and re-raised, so the task finishes in the
    cancelled state. A client disconnect ends the loop quietly. Any other
    error is logged with a traceback and ends the loop.
    """
    try:
        while True:
            send_data = await async_q.get()

            logging.info("Sending to client: %s", send_data)

            await websocket.send_json({
                "is_final": send_data['is_final'],
                "transcript": send_data['transcript'],
            })

            logging.info("Successfully sent: %s", send_data['is_final'])

            # Pace outgoing messages. NOTE(review): this adds up to 100 ms per
            # message, so a backlog can build when interim results arrive fast.
            await asyncio.sleep(0.1)

    except asyncio.CancelledError:
        logging.warning("Send task was cancelled.")
        # Swallowing cancellation hides it from anyone awaiting the task.
        raise
    except WebSocketDisconnect:
        logging.info("Client disconnected while sending.")
    except Exception as e:
        logging.error(f"Error in send_transcription: {e}", exc_info=True)


def _log_stt_task_result(task: asyncio.Future) -> None:
    """Done-callback for the recognition worker task: log its exception, if any."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logging.error("Speech recognition worker failed: %s", exc, exc_info=exc)


@app.websocket("/wslang/zh/")
async def websocket_endpoint(websocket: WebSocket):
    """Stream client audio to the recognizer and push transcripts back.

    Binary frames are forwarded to the recognizer as audio. A text frame with
    the JSON object ``{"final": true}`` ends the session and closes the
    socket. Other text is logged and ignored. Recognition runs in
    ``SpeechClientBridge.start`` on a worker thread. Its results travel
    through a janus queue to ``send_transcription``.
    """
    await websocket.accept()
    logging.info("WebSocket connection opened.")

    q = janus.Queue()
    bridge = SpeechClientBridge(streaming_config, on_transcription_response)

    # Keep a reference so the task is not garbage-collected while running.
    # The callback surfaces a worker crash instead of leaving
    # "Task exception was never retrieved".
    stt_task = asyncio.create_task(asyncio.to_thread(bridge.start, q.sync_q, websocket))
    stt_task.add_done_callback(_log_stt_task_result)

    send_task = asyncio.create_task(send_transcription(websocket, q.async_q))

    client_finished = False
    try:
        while True:
            data = await websocket.receive()

            # receive() returns the disconnect message rather than raising.
            # Calling receive() again after it would raise RuntimeError.
            if data.get("type") == "websocket.disconnect":
                logging.info("Client disconnected.")
                break

            message = data.get("text")
            if message is not None:
                logging.info("Received text message: %s", message)
                try:
                    json_data = json.loads(message)
                except json.JSONDecodeError:
                    logging.warning("Received non-JSON text: %s", message)
                    continue

                # The client signals it is done with {"final": true}.
                if isinstance(json_data, dict) and json_data.get('final') is True:
                    logging.info("Client confirmed 'final: True'. Closing connection.")
                    client_finished = True
                    break
                continue

            audio = data.get("bytes")
            if audio:
                bridge.add_request(audio)

    except WebSocketDisconnect:
        logging.info("Client disconnected.")
    except Exception as e:
        logging.error(f"Exception in websocket_endpoint: {e}", exc_info=True)

    finally:
        logging.info("Cleaning up resources...")
        bridge.terminate()
        # Wake the recognizer's audio generator, which may be blocked in
        # bridge._queue.get(), so the worker thread can finish. An extra
        # sentinel is harmless because the generator returns on the first one.
        bridge._queue.put(None)

        send_task.cancel()
        await asyncio.gather(send_task, return_exceptions=True)

        if client_finished:
            try:
                await websocket.close()
            except Exception as e:
                logging.warning("Error while closing WebSocket: %s", e)

        logging.info("WebSocket connection closed.")

