"""4-stage screening pipeline orchestrator."""

from __future__ import annotations
import asyncio
import json
import logging
import re
import threading
from typing import Any, Callable
import pandas as pd

from .labelling import get_labelling_scores
from .reasoning import get_rating_reasoning
from .clustering import get_reasoning_clusters
from .cluster_selection import get_paper_clusters
from .constants import CLUSTERING_MODEL_NAME

logger = logging.getLogger(__name__)


class ScreeningError(Exception):
    """Screening failure carrying a user-facing message and a coarse category.

    When built through :meth:`from_exception` the category is one of
    ``"auth"``, ``"rate_limit"``, ``"network"`` or ``"unknown"``; direct
    construction accepts any category string (default ``"unknown"``).
    """

    def __init__(self, message: str, category: str = "unknown"):
        super().__init__(message)
        # Coarse classification callers can branch on without parsing the message.
        self.category = category

    @classmethod
    def from_exception(cls, exc: Exception) -> "ScreeningError":
        """Translate an arbitrary exception into a friendly ``ScreeningError``.

        Classification is heuristic: it looks at the exception's class name
        (case-insensitively) and, for auth / rate-limit, whether the exception
        text contains the HTTP status ``401`` / ``429``. In the fallback case
        the original text is kept, with anything shaped like an ``sk-`` API key
        masked out.
        """
        type_name = type(exc).__name__.lower()
        text = str(exc)

        if "auth" in type_name or "401" in text:
            return cls("OpenAI authentication failed. Check your API key.", category="auth")

        if "ratelimit" in type_name or "429" in text:
            return cls("OpenAI rate limit exceeded. Wait and retry.", category="rate_limit")

        if "timeout" in type_name or "connection" in type_name:
            return cls("Network error connecting to OpenAI.", category="network")

        # Never echo a secret key back to the user.
        masked = re.sub(r"sk-[A-Za-z0-9_-]{10,}", "sk-***", text)
        return cls(f"Screening failed: {masked}", category="unknown")


def convert_criteria_format(criteria_list: list[dict]) -> dict[str, dict[str, list[str]]]:
    """Convert a flat criteria list into the nested dict used by pipeline stages.

    Each criterion is a dict with optional keys:

    * ``name``  - grouping category; missing/``None``/empty falls back to ``"Other"``.
    * ``type``  - ``"include"`` or ``"exclude"`` (case-insensitive); missing or
      ``None`` defaults to ``"exclude"``.
    * ``value`` - criterion text; non-strings are converted with ``str``.

    Criteria whose value is missing, ``None`` or blank are skipped, and stored
    values have surrounding whitespace removed.

    Returns:
        Mapping ``category -> {"include": [...], "exclude": [...]}`` in the
        order categories are first seen.

    Raises:
        ValueError: if a criterion with a non-blank value has a ``type`` other
            than include/exclude.
    """
    inc_excl_dict: dict[str, dict[str, list[str]]] = {}
    for criterion in criteria_list or ():
        raw_value = criterion.get("value")
        # Check for None explicitly: str(None) == "None" would otherwise be kept.
        if raw_value is None:
            continue
        value = str(raw_value).strip()
        if not value:
            continue

        category = criterion.get("name") or "Other"
        ctype = str(criterion.get("type") or "exclude").strip().lower()
        if ctype not in ("include", "exclude"):
            raise ValueError(
                f"Unknown criterion type {criterion.get('type')!r} for category "
                f"{category!r}; expected 'include' or 'exclude'"
            )

        inc_excl_dict.setdefault(category, {"include": [], "exclude": []})[ctype].append(value)
    return inc_excl_dict


def _has_score(score: Any) -> bool:
    """Return True if ``score`` is present (not None / NaN / pd.NA)."""
    return not pd.isna(score)


def _is_included(score: Any, threshold: float) -> bool:
    """Return True when a paper with ``score`` counts as an inclusion.

    Missing scores (None/NaN/pd.NA) count as exclusions. This is the single
    decision rule shared by reasoning partitioning and ``ai_decision``.
    """
    return bool(_has_score(score) and score > threshold)


def _has_reasoning(value: Any) -> bool:
    """Return True if ``value`` is usable reasoning to feed into clustering.

    Blank strings, None and NaN/pd.NA placeholders are rejected. NaN is truthy
    as a float, so a plain truthiness test would let it through.
    """
    if isinstance(value, str):
        return bool(value.strip())
    if value is None or value is pd.NA:
        return False
    if isinstance(value, float) and value != value:  # NaN is the only float unequal to itself
        return False
    return bool(value)


def _split_reasoning(df: pd.DataFrame, threshold: float) -> tuple[list[Any], list[Any]]:
    """Partition papers' ``rating_reasoning`` into (include, exclude) lists by ``mean_score``.

    Papers without a usable score or without usable reasoning are left out of
    both lists. If either column is missing, both lists are empty.
    """
    include: list[Any] = []
    exclude: list[Any] = []
    if "mean_score" not in df.columns or "rating_reasoning" not in df.columns:
        return include, exclude
    for score, reasoning in zip(df["mean_score"], df["rating_reasoning"]):
        if not (_has_score(score) and _has_reasoning(reasoning)):
            continue
        target = include if _is_included(score, threshold) else exclude
        target.append(reasoning)
    return include, exclude


async def _cluster_reasoning(
    reasoning: list[Any],
    *,
    cluster_type: str,
    label: str,
    clustering_model: str,
    inc_excl_dict: dict[str, dict[str, list[str]]],
    cancel_event: threading.Event | None,
) -> list[dict[str, Any]]:
    """Cluster one side's reasoning into theme dicts tagged with ``cluster_type``.

    ``label`` ("included"/"excluded") is passed through to
    ``get_reasoning_clusters``. This is best-effort: any failure, including a
    malformed cluster entry, is logged with its traceback and yields ``[]`` so
    the rest of the pipeline can continue.
    """
    try:
        raw_clusters = await asyncio.to_thread(
            get_reasoning_clusters,
            reasoning,
            clustering_model,
            inc_excl_dict,
            label,
            cancel_event=cancel_event,
        )
        return [
            {
                "cluster_type": cluster_type,
                "cluster_name": c["cluster_name"],
                "cluster_description": c["cluster_description"],
                "related_criteria": c.get("related_criteria", "Other"),
            }
            for c in raw_clusters
        ]
    except Exception as e:
        logger.exception("%s clustering failed: %s", cluster_type.capitalize(), e)
        return []


def _assigned_clusters(row: pd.Series) -> list[Any]:
    """Return the cluster entries recorded for a paper on its own decision side.

    Reads ``included_cluster_list`` for papers whose ``ai_decision`` is
    "include", else ``excluded_cluster_list``. The cell may hold a JSON-encoded
    list or an already-decoded list; anything else (missing, NaN, invalid JSON,
    a non-list JSON value) yields ``[]``.
    """
    if row.get("ai_decision", "") == "include":
        col = "included_cluster_list"
    else:
        col = "excluded_cluster_list"
    raw = row.get(col)
    if isinstance(raw, list):
        return list(raw)
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return []
        if isinstance(parsed, list):
            return parsed
    return []


async def screen_papers(
    *,
    papers_df: pd.DataFrame,
    criteria: list[dict],
    questions: list[str],
    model_name: str = "gpt-5-nano",
    clustering_model: str = CLUSTERING_MODEL_NAME,
    repetitions: int = 5,
    threshold: float = 1.0,
    clusters_type: str | None = None,
    progress_callback: Callable[[int, int, str], None] | None = None,
    cancel_event: threading.Event | None = None,
) -> tuple[pd.DataFrame, list[dict[str, Any]]]:
    """
    Run the full 4-stage screening pipeline.

    Stages, each run in a worker thread via ``asyncio.to_thread``:

    1. Labelling (``get_labelling_scores``, ``repetitions`` times). Its output
       is expected to carry a ``mean_score`` column.
    2. Reasoning (``get_rating_reasoning``). Expected to add ``rating_reasoning``.
    3. Clustering the reasoning of included (``mean_score > threshold``) and
       excluded papers into themes. ``clusters_type="include"`` /
       ``"exclude"`` restricts this to one side. Clustering failures are
       logged and skipped rather than raised.
    4. Assigning papers to the discovered clusters (``get_paper_clusters``).

    Finally ``ai_decision`` ("include"/"exclude"; missing scores count as
    "exclude") and ``assigned_clusters`` columns are added.

    ``progress_callback(done, total, message)`` is called before and after each
    stage, with ``total = 4 * len(papers_df)``. If ``cancel_event`` is set
    between stages, the partial DataFrame is returned early, together with the
    clusters found so far and without the final result columns. Errors from
    stages 1, 2 and 4 propagate to the caller.

    Returns (results_df, clusters_list).
    """
    inc_excl_dict = convert_criteria_format(criteria)
    n_papers = len(papers_df)
    total_items = n_papers * 4
    current = 0

    def _report(message: str) -> None:
        # Reads ``current`` at call time, so it reflects progress so far.
        if progress_callback:
            progress_callback(current, total_items, message)

    def _cancelled() -> bool:
        return cancel_event is not None and cancel_event.is_set()

    # Stage 1: Labelling
    _report("Scoring papers")
    if _cancelled():
        return papers_df, []

    df = await asyncio.to_thread(
        get_labelling_scores,
        papers_df,
        model_name,
        questions,
        inc_excl_dict,
        repetitions,
        cancel_event=cancel_event,
    )
    current += n_papers
    _report("Scoring papers")

    # Stage 2: Reasoning
    if _cancelled():
        return df, []
    _report("Generating explanations")

    df = await asyncio.to_thread(
        get_rating_reasoning,
        df,
        model_name,
        questions,
        inc_excl_dict,
        cancel_event=cancel_event,
    )
    current += n_papers
    _report("Generating explanations")

    # Stage 3: Clustering
    if _cancelled():
        return df, []
    _report("Clustering themes")

    include_reasoning, exclude_reasoning = _split_reasoning(df, threshold)
    clusters: list[dict[str, Any]] = []

    if include_reasoning and clusters_type != "exclude":
        clusters += await _cluster_reasoning(
            include_reasoning,
            cluster_type="include",
            label="included",
            clustering_model=clustering_model,
            inc_excl_dict=inc_excl_dict,
            cancel_event=cancel_event,
        )
    if exclude_reasoning and clusters_type != "include":
        clusters += await _cluster_reasoning(
            exclude_reasoning,
            cluster_type="exclude",
            label="excluded",
            clustering_model=clustering_model,
            inc_excl_dict=inc_excl_dict,
            cancel_event=cancel_event,
        )

    current += n_papers
    _report("Clustering themes")

    # Stage 4: Cluster Assignment
    if _cancelled():
        return df, clusters
    _report("Assigning clusters")

    for cluster_type, label in (("include", "included"), ("exclude", "excluded")):
        side_clusters = [c for c in clusters if c["cluster_type"] == cluster_type]
        if side_clusters:
            df = await asyncio.to_thread(
                get_paper_clusters,
                df,
                model_name,
                side_clusters,
                label,
                cancel_event=cancel_event,
            )

    # Build results
    df["ai_decision"] = df["mean_score"].apply(
        lambda score: "include" if _is_included(score, threshold) else "exclude"
    )
    df["assigned_clusters"] = df.apply(_assigned_clusters, axis=1)

    current += n_papers
    _report("Complete")

    return df, clusters
