import asyncio
import io
import os
import re
import time
from typing import List

import boto3
import openai
from fastapi import APIRouter, HTTPException
from openai import OpenAI
from pydantic import BaseModel

from config import openai_api_key
from db import database  # ✅ DB connection

router = APIRouter()

# Shared OpenAI client. The module-level ``openai.api_key`` is set as well so any
# code that uses the module-level API (``openai.chat...``) is authenticated too.
client = OpenAI(api_key=openai_api_key)
openai.api_key = openai_api_key

# S3 client and bucket used to publish generated speech audio.
# The bucket can be overridden per environment (e.g. staging vs. production)
# through TTS_S3_BUCKET; without it the original bucket is used unchanged.
s3_client = boto3.client('s3')
bucket_name = os.environ.get("TTS_S3_BUCKET", 'shanri-ai-chatbot-for-text-to-speech')

# 📌 Models
class Message(BaseModel):
    """One chat message; the list of these is forwarded as-is to the chat completion API."""

    # Speaker role, e.g. "user" / "assistant"; not validated by this model.
    role: str
    # Message text.
    content: str

class ChatRequest(BaseModel):
    """Request body for ``POST /sanwa/gpt/ask_question``."""

    # Full chat history sent to the model; the last entry is treated as the new user turn.
    messages: List[Message]
    # Conversation to append to; None means "start a new conversation".
    conversation_id: int | None = None

class SpeechText(BaseModel):
    """Request body for ``POST /sanwa/gpt/speech``."""

    # Text to synthesize into speech.
    text: str
    # Client-supplied token; in this file it is only used to name the generated audio object.
    chat_token: str

# ✅ Generate a GPT answer and store the exchange
@router.post("/sanwa/gpt/ask_question")
async def ask_question(chat: ChatRequest):
    """Send the chat history to GPT and persist the latest exchange.

    The last entry of ``chat.messages`` is stored as the user turn and the
    model reply as the assistant turn. When ``chat.conversation_id`` is None a
    new conversation is created, titled with the first 20 characters of the
    user message. Otherwise the exchange is appended to the given conversation
    and its ``updated_at`` is refreshed.

    Returns:
        ``{"conversation_id": <int>, "answer": <str>}``

    Raises:
        HTTPException: 400 if ``messages`` is empty, 502 if the OpenAI call fails.
    """
    if not chat.messages:
        raise HTTPException(status_code=400, detail="messages is required")

    user_message = chat.messages[-1].content.strip()
    title = user_message[:20]

    answer = await _generate_answer(chat.messages)

    conversation_id = await _resolve_conversation_id(chat.conversation_id, title)
    await _append_exchange(conversation_id, user_message, answer)

    return {
        "conversation_id": conversation_id,
        "answer": answer
    }


async def _generate_answer(messages: List[Message]) -> str:
    """Run the chat completion off the event loop and return the reply text."""
    payload = [{"role": m.role, "content": m.content} for m in messages]
    try:
        # The OpenAI client is synchronous; calling it directly inside an async
        # endpoint would block every other request until GPT responds.
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="gpt-4o",
            messages=payload,
        )
    except openai.OpenAIError as exc:
        raise HTTPException(status_code=502, detail="OpenAI request failed") from exc
    # content can be None (e.g. refusals / tool calls); treat that as empty text.
    return (response.choices[0].message.content or "").strip()


async def _resolve_conversation_id(conversation_id: int | None, title: str) -> int:
    """Return the conversation to write to, creating a new one when none is given."""
    if conversation_id is not None:
        # The conversation list is ordered by updated_at, so record new activity.
        await database.execute("""
            UPDATE conversations SET updated_at = NOW()
            WHERE id = :conversation_id
        """, {"conversation_id": conversation_id})
        return conversation_id

    new_id = await database.execute("""
        INSERT INTO conversations (title, is_deleted, created_at, updated_at)
        VALUES (:title, 0, NOW(), NOW())
    """, {"title": title})
    if new_id:
        # Exact id of the row we just inserted. This assumes ``database`` is an
        # encode/databases ``Database``, whose MySQL/SQLite backends return lastrowid.
        return new_id

    # NOTE(review): backends that return nothing from execute() (presumably
    # PostgreSQL without RETURNING) fall back to the old lookup by title. That
    # lookup can pick up a concurrent request's row when titles collide.
    return await database.fetch_val("""
        SELECT id FROM conversations
        WHERE title = :title
        ORDER BY id DESC
        LIMIT 1
    """, {"title": title})


async def _append_exchange(conversation_id: int, user_message: str, answer: str) -> None:
    """Store the user turn and assistant reply after the conversation's last message."""
    max_order = await database.fetch_val("""
        SELECT COALESCE(MAX(sort_order), 0) FROM messages
        WHERE conversation_id = :conversation_id
    """, {"conversation_id": conversation_id})

    insert_msg = """
        INSERT INTO messages (conversation_id, role, content, sort_order, is_deleted, created_at, updated_at)
        VALUES (:conversation_id, :role, :content, :sort_order, 0, NOW(), NOW())
    """
    # The latest request message is always stored with role "user".
    await database.execute(insert_msg, {
        "conversation_id": conversation_id,
        "role": "user",
        "content": user_message,
        "sort_order": max_order + 1
    })
    await database.execute(insert_msg, {
        "conversation_id": conversation_id,
        "role": "assistant",
        "content": answer,
        "sort_order": max_order + 2
    })

# ✅ List conversations
@router.get("/sanwa/gpt/conversations")
async def get_conversations():
    """Return all non-deleted conversations, most recently updated first.

    Each item is a dict with ``id``, ``title`` and ``updated_at``.
    """
    query = """
        SELECT id, title, updated_at FROM conversations
        WHERE is_deleted = 0
        ORDER BY updated_at DESC
    """
    rows = await database.fetch_all(query)
    return list(map(dict, rows))

# ✅ Fetch the messages of one conversation
@router.get("/sanwa/gpt/messages/{conversation_id}")
async def get_conversation_messages(conversation_id: int):
    """Return the non-deleted messages of ``conversation_id`` ordered by ``sort_order``.

    Each item is a dict with ``role`` and ``content``.
    """
    params = {"conversation_id": conversation_id}
    query = """
        SELECT role, content FROM messages
        WHERE conversation_id = :conversation_id AND is_deleted = 0
        ORDER BY sort_order ASC
    """
    records = await database.fetch_all(query, params)
    return list(map(dict, records))

# ✅ Text-to-speech
@router.post("/sanwa/gpt/speech")
async def speech(speech_text: SpeechText):
    """Synthesize ``speech_text.text`` with OpenAI TTS and upload the MP3 to S3.

    Returns:
        ``{"audio_file": <public S3 URL of the uploaded MP3>}``

    Raises:
        HTTPException: 400 if the text is empty or blank, 502 if TTS fails.
    """
    text = speech_text.text
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text is required")

    try:
        # Blocking network call; keep it off the event loop.
        audio_bytes = await asyncio.to_thread(_synthesize_speech, text)
    except openai.OpenAIError as exc:
        raise HTTPException(status_code=502, detail="Speech synthesis failed") from exc

    # chat_token comes from the client: replace anything outside a safe charset so
    # it cannot inject path separators or URL-special characters into the key/URL.
    safe_token = re.sub(r"[^A-Za-z0-9._-]", "_", speech_text.chat_token)
    s3_key = f"{safe_token}-{time.time()}.mp3"

    # Upload straight from memory: no local temp file to leak when the upload
    # fails, and no dependency on a relative tmp/ directory existing. The
    # explicit ContentType marks the object as MP3 audio.
    await asyncio.to_thread(
        s3_client.upload_fileobj,
        io.BytesIO(audio_bytes),
        bucket_name,
        s3_key,
        ExtraArgs={"ContentType": "audio/mpeg"},
    )

    return {
        "audio_file": f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
    }


def _synthesize_speech(text: str) -> bytes:
    """Run the blocking OpenAI TTS request and return the MP3 audio bytes."""
    response = client.audio.speech.create(
        model="tts-1",
        voice="nova",
        input=text,
    )
    return b"".join(response.iter_bytes())

# ✅ Health check
@router.get("/health")
async def health_check():
    """Report service health; always returns ``{"status": "healthy"}`` without probing dependencies."""
    status = "healthy"
    return dict(status=status)
