
import chromadb
from chromadb.utils import embedding_functions
import os
from sqlalchemy.orm import Session
from models import DictionaryItem, BibleQA
from rank_bm25 import BM25Okapi
from janome.tokenizer import Tokenizer

# ... (Existing Chroma Setup) ...

# Global BM25 Index
# bm25_index: rank_bm25.BM25Okapi built over tokenized questions, or None until built.
bm25_index = None
bm25_doc_map = {} # Maps BM25 corpus position (int) to {"id": 1, "question": "...", "answer": "..."}
tokenizer = Tokenizer()  # Janome tokenizer, shared by all tokenization helpers below.

def tokenize_japanese(text):
    """Tokenize *text* with Janome and return only noun (名詞) surfaces, for BM25."""
    return [
        tok.surface
        for tok in tokenizer.tokenize(text)
        if tok.part_of_speech.split(',')[0] == '名詞'
    ]

def extract_query_details(text):
    """
    Tokenize *text* and return a tuple (noun_tokens, proper_nouns).

    proper_nouns is the subset of nouns whose sub-POS is 固有名詞
    (proper noun: person, place, organization).
    """
    noun_tokens = []
    proper_nouns = []
    for tok in tokenizer.tokenize(text):
        pos_parts = tok.part_of_speech.split(',')
        if pos_parts[0] != '名詞':
            continue  # Keep nouns only, matching tokenize_japanese.
        noun_tokens.append(tok.surface)
        if len(pos_parts) > 1 and pos_parts[1] == '固有名詞':
            proper_nouns.append(tok.surface)
    return noun_tokens, proper_nouns

# ... (Existing Sync/Delete functions) ...

def rebuild_qa_index(db: Session):
    """
    Wipe and rebuild both BibleQA indexes: the ChromaDB collection and the
    in-memory BM25 index / doc map.

    Fixes:
    - bm25_index is reset up front, so an empty table can no longer leave a
      stale index paired with an empty bm25_doc_map.
    - Each bm25_doc_map entry now stores its token list, matching init_bm25,
      so the keyword-coverage check in search_bible_qa_hybrid keeps working
      after a rebuild.
    """
    items = db.query(BibleQA).all()

    # 1. Rebuild Chroma: drop and recreate the collection.
    try:
        chroma_client.delete_collection("bible_qa")
    except Exception:
        pass  # Collection may not exist yet.

    global qa_collection
    qa_collection = chroma_client.get_or_create_collection(name="bible_qa", embedding_function=embedding_fn)

    # 2. Rebuild BM25 from scratch.
    global bm25_index, bm25_doc_map
    bm25_index = None
    bm25_doc_map = {}
    tokenized_corpus = []

    docs = []
    metas = []
    ids = []

    for i, item in enumerate(items):
        # For Chroma: the question text is what gets embedded.
        docs.append(item.question)
        metas.append({"answer": item.answer, "id": str(item.id)})
        ids.append(str(item.id))

        # For BM25: positional index -> row data (+ tokens for coverage check).
        tokens = tokenize_japanese(item.question)
        tokenized_corpus.append(tokens)
        bm25_doc_map[i] = {
            "id": item.id,
            "question": item.question,
            "answer": item.answer,
            "tokens": tokens,
        }

    # Batch add to Chroma.
    if docs:
        qa_collection.add(documents=docs, metadatas=metas, ids=ids)

    # Init BM25 only when there is a corpus; otherwise it stays None.
    if tokenized_corpus:
        bm25_index = BM25Okapi(tokenized_corpus)
def init_bm25(db: Session):
    """
    Initialize the in-memory BM25 index from the BibleQA table without
    rebuilding the Chroma collection. Call this on server startup.

    Fixes: the unused `tokenizer` name is dropped from the `global` statement
    (it is never assigned here), and bm25_index is reset even when the table
    is empty so a stale index cannot survive re-initialization.
    """
    global bm25_index, bm25_doc_map
    bm25_index = None
    bm25_doc_map = {}
    tokenized_corpus = []

    items = db.query(BibleQA).all()
    if not items:
        print("No items for BM25.")
        return

    print(f"Initializing BM25 for {len(items)} items...")
    for i, item in enumerate(items):
        tokens = tokenize_japanese(item.question)
        tokenized_corpus.append(tokens)
        # Store tokens alongside the row so the coverage check can reuse them.
        bm25_doc_map[i] = {
            "id": item.id,
            "question": item.question,
            "answer": item.answer,
            "tokens": tokens,
        }

    if tokenized_corpus:
        bm25_index = BM25Okapi(tokenized_corpus)
        print("BM25 Index built successfully.")

def search_bible_qa_hybrid(query: str, k: int = 5):
    """
    Hybrid search over BibleQA: vector (Chroma) + keyword (BM25), combined
    with Reciprocal Rank Fusion (RRF, constant 60).

    Two relevance filters adjust the fused score:
    - Keyword coverage (majority rule): if fewer than half of the query's
      noun tokens appear in a candidate document, both rank scores are
      multiplied by 0.01 so partial matches sink.
    - Strict proper-noun check: if the query contains proper nouns (固有名詞)
      and any of them is missing from the document, scores are multiplied
      by 0.001 — a missing name is a critical mismatch.

    Returns up to *k* dicts: {"id": str, "score": float (higher is better,
    max ~0.03), "coverage": float, "metadata": {"question", "answer"}}.
    Note "score" is an RRF score, not a vector distance.

    Fixes vs. previous version: the per-result linear scan of bm25_doc_map
    is replaced by a single id->data dict, coverage is computed once and
    reused for reporting, and the unused-key `.items()` iteration is gone.
    """
    k_retrieval = 10  # Candidate pool size taken from each retriever before fusion.

    # 1. Vector search (Chroma).
    vector_results = qa_collection.query(query_texts=[query], n_results=k_retrieval)
    vector_ranks = {}  # item id -> 1-based rank
    if vector_results['ids']:
        for rank, item_id in enumerate(vector_results['ids'][0]):
            vector_ranks[int(item_id)] = rank + 1

    # 2. BM25 keyword search.
    bm25_ranks = {}  # item id -> 1-based rank
    query_tokens = []
    proper_nouns = []
    if bm25_index:
        query_tokens, proper_nouns = extract_query_details(query)
        doc_scores = bm25_index.get_scores(query_tokens)
        top_n_indices = sorted(
            range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True
        )[:k_retrieval]
        for rank, idx in enumerate(top_n_indices):
            if doc_scores[idx] > 0:  # Only keep positive (actual) matches.
                item_data = bm25_doc_map.get(idx)
                if item_data:
                    bm25_ranks[item_data['id']] = rank + 1

    # 3. RRF fusion: score = 1 / (60 + rank).
    all_ids = set(vector_ranks) | set(bm25_ranks)

    # bm25_doc_map is keyed by corpus position; re-key once by item id so the
    # loops below do O(1) lookups instead of scanning the whole map.
    id_to_data = {v['id']: v for v in bm25_doc_map.values()} if bm25_index else {}

    query_token_set = set(query_tokens)
    proper_noun_set = set(proper_nouns)

    fused_scores = {}
    coverages = {}  # item id -> keyword coverage ratio, kept for reporting

    for item_id in all_ids:
        v_score = 1 / (60 + vector_ranks[item_id]) if item_id in vector_ranks else 0
        b_score = 1 / (60 + bm25_ranks[item_id]) if item_id in bm25_ranks else 0

        doc_tokens = set(id_to_data.get(item_id, {}).get('tokens', []))

        # KEYWORD COVERAGE CHECK (majority rule).
        coverage = 0.0
        if query_token_set:
            coverage = len(query_token_set & doc_tokens) / len(query_token_set)
            if coverage < 0.5:
                # Less than half the keywords matched: likely a partial match
                # (e.g. one name found, the other missing) — heavy penalty.
                v_score *= 0.01
                b_score *= 0.01

        # STRICT PROPER NOUN CHECK: all query proper nouns must be present.
        if proper_noun_set and not proper_noun_set.issubset(doc_tokens):
            v_score *= 0.001
            b_score *= 0.001

        fused_scores[item_id] = v_score + b_score
        coverages[item_id] = coverage

    # Top-k by fused score, descending.
    sorted_ids = sorted(fused_scores, key=fused_scores.get, reverse=True)[:k]

    final_results = []
    for item_id in sorted_ids:
        item_data = id_to_data.get(item_id)
        final_results.append({
            "id": str(item_id),
            "score": fused_scores[item_id],  # Higher is better.
            "coverage": coverages[item_id],
            "metadata": {
                "question": item_data['question'] if item_data else "",
                "answer": item_data['answer'] if item_data else "",
            },
        })

    return final_results

# ChromaDB Setup
# Persist data in a folder named 'chroma_db' next to this source file.
CHROMA_DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)

# Embedding Function: prefer Gemini embeddings when an API key is configured;
# fall back to Chroma's bundled default embedding model otherwise.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

if GEMINI_API_KEY:
    try:
        embedding_fn = embedding_functions.GoogleGenerativeAiEmbeddingFunction(api_key=GEMINI_API_KEY)
        print("Using Gemini Embeddings")
    except Exception as e:
        # Best-effort fallback so the app still starts without Gemini access.
        print(f"Failed to init Gemini Embeddings: {e}")
        embedding_fn = embedding_functions.DefaultEmbeddingFunction()
else:
    print("GEMINI_API_KEY not found, using default embeddings.")
    embedding_fn = embedding_functions.DefaultEmbeddingFunction()

# Two collections: dictionary entries and Q&A pairs, both bound to embedding_fn.
collection = chroma_client.get_or_create_collection(name="bible_dictionary", embedding_function=embedding_fn)
qa_collection = chroma_client.get_or_create_collection(name="bible_qa", embedding_function=embedding_fn)

def sync_item(item: DictionaryItem):
    """Upsert a single dictionary entry into the Chroma collection."""
    # Embed "Term: Definition (Verses)" as one document.
    document = f"{item.term}: {item.definition} ({item.verses})"
    metadata = {"term": item.term, "verses": item.verses}
    collection.upsert(
        documents=[document],
        metadatas=[metadata],
        ids=[str(item.term)],  # Term doubles as the ID for easy deduplication.
    )

def delete_item(term: str):
    """
    Remove the dictionary entry with the given term from ChromaDB.

    Deletion is best-effort: a missing entry (or any Chroma error) is
    ignored so callers may delete unconditionally. Narrowed from a bare
    `except:` so KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    try:
        collection.delete(ids=[term])
    except Exception:
        pass  # Ignore if not found.

def rebuild_index(db: Session):
    """
    Wipe and rebuild the dictionary ChromaDB collection from MySQL.
    Useful for 'Reset' or 'Bulk Import' features.

    BUG FIX: the recreated collection previously omitted
    `embedding_function=embedding_fn`, so after a rebuild the index silently
    used Chroma's default embeddings — inconsistent with the module-level
    collection and with rebuild_qa_index.
    """
    # 1. Get all items from MySQL.
    items = db.query(DictionaryItem).all()

    # 2. Reset the collection. Chroma has no fast 'truncate', so dropping and
    # recreating the collection is the cleanest approach.
    try:
        chroma_client.delete_collection("bible_dictionary")
    except Exception:
        pass  # Collection may not exist yet.

    global collection
    collection = chroma_client.get_or_create_collection(
        name="bible_dictionary", embedding_function=embedding_fn
    )

    if not items:
        return

    # 3. Batch add (a single add call is fine for typical dictionary sizes).
    docs = []
    metas = []
    ids = []
    for item in items:
        docs.append(f"{item.term}: {item.definition} ({item.verses})")
        metas.append({"term": item.term, "verses": item.verses or ""})
        ids.append(str(item.term))

    if docs:
        collection.add(documents=docs, metadatas=metas, ids=ids)

def search_dictionary(query: str, k: int = 3):
    """Return the top-*k* dictionary documents (strings) matching *query*."""
    hits = collection.query(query_texts=[query], n_results=k)

    documents = hits['documents']
    if not documents:
        return []

    # Chroma nests results per query text; we issued exactly one query.
    return documents[0]

# --- BibleQA Chroma Logic ---

def sync_qa_item(item: BibleQA):
    """
    Upsert a single BibleQA row into the Chroma QA collection.

    Only the question text is embedded (search runs against questions);
    the answer rides along as metadata.
    """
    qa_collection.upsert(
        documents=[item.question],
        metadatas=[{"answer": item.answer, "id": str(item.id)}],
        ids=[str(item.id)],
    )

def delete_qa_item(item_id: int):
    """
    Best-effort delete of a BibleQA row from the Chroma QA collection.

    Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
    are no longer swallowed; a missing ID is still ignored.
    """
    try:
        qa_collection.delete(ids=[str(item_id)])
    except Exception:
        pass  # Ignore if the ID is not present.



def search_bible_qa(query: str, k: int = 1):
    """
    Vector-search the BibleQA collection for questions similar to *query*.

    Returns up to *k* matches as dicts:
        {"id": str, "distance": float, "metadata": dict}
    Chroma's default metric is L2 (Euclidean), so lower distance = closer.
    """
    results = qa_collection.query(query_texts=[query], n_results=k)

    if not results['ids']:
        return []

    # Results are nested per query text; we issued exactly one query.
    return [
        {"id": doc_id, "distance": dist, "metadata": meta}
        for doc_id, dist, meta in zip(
            results['ids'][0],
            results['distances'][0],
            results['metadatas'][0],
        )
    ]
