"""Screening API endpoints with async job pattern and DB persistence."""

from __future__ import annotations

import json
import logging
import math
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Any

import pandas as pd
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request

from api.schemas.screening import (
    ScreeningCostEstimateRequest,
    ScreeningCostEstimateResponse,
    ScreeningJobListItem,
    ScreeningJobResponse,
    ScreeningRequest,
    ScreeningResultResponse,
)
from crystallise.llm.cost import estimate_cost

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router on which the screening endpoints defined in this module are registered.
router = APIRouter()

# ── DB helpers ──

# DDL for the screening_jobs table, executed by `_ensure_table` when the table
# does not exist yet. The config, results, clusters and stage_timings columns
# are TEXT: `_save_job` writes them as JSON strings and `_row_to_dict` decodes
# them again on read.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS screening_jobs (
    id TEXT PRIMARY KEY,
    project_id INTEGER,
    status TEXT NOT NULL DEFAULT 'pending',
    progress REAL NOT NULL DEFAULT 0.0,
    stage TEXT NOT NULL DEFAULT '',
    config TEXT NOT NULL DEFAULT '{}',
    results TEXT,
    clusters TEXT,
    stage_timings TEXT DEFAULT '{}',
    duration_ms INTEGER,
    estimated_cost_usd REAL,
    error TEXT,
    error_category TEXT,
    error_retryable INTEGER,
    model_version TEXT,
    created_at TEXT NOT NULL,
    completed_at TEXT
)
"""

# Set to True by `_ensure_table` after a successful create/migrate pass so that
# later calls return immediately.
_table_created = False
# When False, every DB helper in this module becomes a no-op / returns an
# empty result instead of touching the database.
_db_available = True  # Set to False if DB init fails (e.g. in tests)


def _get_db():
    """Return the database backend object used by the helpers in this module.

    The backend module is imported at call time rather than at module import
    time.
    """
    import crystallise.db.backend as backend_module

    return backend_module.get_backend()


# (column name, column type) pairs that `_ensure_table` tries to add to an
# already-existing screening_jobs table. Both values are interpolated directly
# into the ALTER TABLE / SAVEPOINT statements there, so this set must only ever
# contain trusted, hard-coded literals.
_ALLOWED_MIGRATION_COLUMNS: frozenset[tuple[str, str]] = frozenset(
    {
        ("project_id", "INTEGER"),
        ("error_category", "TEXT"),
        ("error_retryable", "INTEGER"),
        ("model_version", "TEXT"),
    }
)


def _ensure_table():
    """Create or migrate the ``screening_jobs`` table, once per process.

    Returns immediately when a previous call already succeeded
    (``_table_created``) or when the database was marked unavailable
    (``_db_available`` is False).

    If the table is missing it is created from ``_CREATE_TABLE_SQL``.
    Otherwise each column listed in ``_ALLOWED_MIGRATION_COLUMNS`` is added
    inside its own SAVEPOINT; a failing ``ALTER TABLE`` is rolled back to that
    savepoint and treated as "column already exists".

    Any other failure is logged as a warning and flips ``_db_available`` to
    False, so the module keeps running without persistence. Never raises.
    """
    global _table_created, _db_available
    if _table_created or not _db_available:
        return
    try:
        db = _get_db()
        with db.get_connection() as conn:
            if not db.table_exists(conn, "screening_jobs"):
                db.executescript(conn, _CREATE_TABLE_SQL)
            else:
                # Migrate: add columns if missing (use SAVEPOINT for PostgreSQL compatibility).
                # The pairs come from a hard-coded constant, so interpolating them into SQL is
                # safe. sorted() makes the ALTER order deterministic; plain frozenset iteration
                # order depends on the per-process string hash seed.
                for col, col_type in sorted(_ALLOWED_MIGRATION_COLUMNS):
                    savepoint = f"sp_{col}"
                    try:
                        db.execute(conn, f"SAVEPOINT {savepoint}")
                        db.execute(conn, f"ALTER TABLE screening_jobs ADD COLUMN {col} {col_type}")
                        db.execute(conn, f"RELEASE SAVEPOINT {savepoint}")
                    except Exception:
                        # Most likely the column already exists; undo just this attempt
                        # so the enclosing transaction stays usable.
                        db.execute(conn, f"ROLLBACK TO SAVEPOINT {savepoint}")
                        logger.debug(
                            "Skipped adding column %s to screening_jobs (assumed already present)",
                            col,
                            exc_info=True,
                        )
        _table_created = True
    except Exception:
        logger.warning(
            "Database init failed for screening_jobs; running without persistence",
            exc_info=True,
        )
        _db_available = False


def _scrub_nan(obj: Any) -> Any:
    """Return a copy of *obj* in which every non-finite float is ``None``.

    Dicts are rebuilt with scrubbed values, lists and tuples are rebuilt as
    lists with scrubbed items, and anything else is returned unchanged.

    This exists because a NaN nested inside a list-valued column (e.g.
    `score_list = [0.5, nan, 0.7]` when a repetition fails) survives
    `to_dict('records')` and crashes `json.dumps`. Ported from
    crystallise-master commit 3430709, which fixed a 9.2h prod run that
    silently lost its results to this bug.
    """
    if isinstance(obj, dict):
        return {key: _scrub_nan(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(_scrub_nan, obj))
    # isfinite() is False for NaN as well as +/-Inf.
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    return obj


def _save_job(job: dict[str, Any]):
    """Write *job* to ``screening_jobs``, updating the row if it exists.

    Does nothing when the database is unavailable. The row is first updated
    by id; if no row was affected it is inserted instead. JSON-valued fields
    are passed through `_scrub_nan` before being serialised.
    """
    _ensure_table()
    if not _db_available:
        return
    backend = _get_db()

    job_id = job["id"]

    # Optional JSON payloads stay NULL when absent.
    results = job.get("results")
    clusters = job.get("clusters")
    results_json = None if results is None else json.dumps(_scrub_nan(job["results"]))
    clusters_json = None if clusters is None else json.dumps(_scrub_nan(job["clusters"]))

    # Tri-state flag: None stays NULL, otherwise stored as 1/0.
    retryable = job.get("error_retryable")
    if retryable is None:
        retryable_flag = None
    else:
        retryable_flag = 1 if retryable else 0

    # Every column except id, in table/statement order.
    values = (
        job.get("project_id"),
        job["status"],
        job["progress"],
        job["stage"],
        json.dumps(_scrub_nan(job.get("config", {}))),
        results_json,
        clusters_json,
        json.dumps(_scrub_nan(job.get("stage_timings", {}))),
        job.get("duration_ms"),
        job.get("estimated_cost_usd"),
        job.get("error"),
        job.get("error_category"),
        retryable_flag,
        job.get("model_version"),
        job["created_at"],
        job.get("completed_at"),
    )

    with backend.get_connection() as conn:
        update_cursor = backend.execute(
            conn,
            """UPDATE screening_jobs SET
                project_id=?, status=?, progress=?, stage=?, config=?,
                results=?, clusters=?, stage_timings=?, duration_ms=?,
                estimated_cost_usd=?, error=?, error_category=?,
                error_retryable=?, model_version=?,
                created_at=?, completed_at=?
               WHERE id=?""",
            values + (job_id,),
        )
        if update_cursor.rowcount == 0:
            # Nothing matched the id, so this is a new job.
            backend.execute(
                conn,
                """INSERT INTO screening_jobs
                   (id, project_id, status, progress, stage, config, results, clusters,
                    stage_timings, duration_ms, estimated_cost_usd, error,
                    error_category, error_retryable, model_version,
                    created_at, completed_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (job_id,) + values,
            )


def _load_job(job_id: str) -> dict[str, Any] | None:
    """Fetch one screening job by id.

    Returns the decoded job dict, or None when no such row exists or the
    database is unavailable.
    """
    _ensure_table()
    if not _db_available:
        return None
    backend = _get_db()
    with backend.get_connection() as conn:
        record = backend.execute(
            conn,
            "SELECT * FROM screening_jobs WHERE id = ?",
            (job_id,),
        ).fetchone()
        return _row_to_dict(record) if record is not None else None


def _list_jobs(limit: int = 50) -> list[dict[str, Any]]:
    """Return up to *limit* screening jobs, most recently created first.

    Returns an empty list when the database is unavailable.
    """
    _ensure_table()
    if not _db_available:
        return []
    backend = _get_db()
    query = "SELECT * FROM screening_jobs ORDER BY created_at DESC LIMIT ?"
    with backend.get_connection() as conn:
        fetched = backend.execute(conn, query, (limit,)).fetchall()
        jobs = []
        for record in fetched:
            jobs.append(_row_to_dict(record))
        return jobs


def _load_latest_for_project(project_id: int) -> dict[str, Any] | None:
    """Return the most recently created 'completed' job for *project_id*.

    Returns None when the project has no completed job or the database is
    unavailable.
    """
    _ensure_table()
    if not _db_available:
        return None
    backend = _get_db()
    query = (
        "SELECT * FROM screening_jobs "
        "WHERE project_id = ? AND status = 'completed' "
        "ORDER BY created_at DESC LIMIT 1"
    )
    with backend.get_connection() as conn:
        record = backend.execute(conn, query, (project_id,)).fetchone()
        if record is not None:
            return _row_to_dict(record)
        return None


def _row_to_dict(row) -> dict[str, Any]:
    """Turn a dict-like DB row into a plain job dict.

    The row is copied with ``dict(row)``. The JSON-bearing fields (config,
    results, clusters, stage_timings) are decoded when they hold a non-empty
    string, and a non-NULL ``error_retryable`` is converted to bool.
    """
    job = dict(row)

    for name in ("config", "results", "clusters", "stage_timings"):
        raw = job.get(name)
        # Empty strings, NULLs and already-decoded values are left alone.
        if isinstance(raw, str) and raw:
            job[name] = json.loads(raw)

    retryable = job.get("error_retryable")
    if retryable is not None:
        # Stored as an integer flag; expose it as a bool.
        job["error_retryable"] = bool(retryable)

    return job


# ── In-memory cache for active jobs (fast polling) ──

# Maps job id -> job dict for jobs started by this process. Entries are added
# by `create_screening_job`, mutated in place by `_run_screening_job`, and
# removed by `get_screening_job` the first time it serves a job that has
# reached 'completed' or 'failed'.
_active_jobs: dict[str, dict[str, Any]] = {}


def _get_active_job_for_project(project_id: int) -> dict[str, Any] | None:
    """Return the currently active screening job for a project, if any.

    Pure read: only returns in-memory entries. DB rows that look 'running'
    without a matching in-memory entry are orphans (server restart, persist
    failure); they are reaped by `_cleanup_stale_running_jobs` on
    POST /screening/jobs, not here. Pre-fix, this function mutated DB state
    on a plain GET — see crystallise-master commit ef47ce3.
    """
    for job in _active_jobs.values():
        if job.get("project_id") == project_id and job["status"] in ("pending", "running"):
            return job
    return None


def _cleanup_stale_running_jobs(project_id: int) -> int:
    """Mark orphaned 'running' rows for a project as failed.

    An orphan is a screening_jobs row with status pending/running whose
    job_id is not in `_active_jobs` (the in-process registry of live jobs).
    Called from POST /screening/jobs so a stale orphan can't 409-block a
    legitimate retry.

    Args:
        project_id: Project whose pending/running rows are inspected.

    Returns:
        The number of rows that were flipped to 'failed'; 0 when the
        database is unavailable.
    """
    _ensure_table()
    if not _db_available:
        return 0
    db = _get_db()
    # Jobs this process still tracks for the project; these must not be reaped.
    active_ids = {
        jid
        for jid, j in _active_jobs.items()
        if j.get("project_id") == project_id
    }
    now_iso = datetime.now(timezone.utc).isoformat()
    reaped = 0
    with db.get_connection() as conn:
        cursor = db.execute(
            conn,
            "SELECT id FROM screening_jobs WHERE project_id = ? "
            "AND status IN ('pending', 'running')",
            (project_id,),
        )
        # Rows are dict-like (see `_row_to_dict`), so read the column by name
        # via the same dict() conversion instead of by position.
        running_ids = [dict(row)["id"] for row in cursor.fetchall()]
        for jid in running_ids:
            if jid in active_ids:
                continue
            db.execute(
                conn,
                """UPDATE screening_jobs
                   SET status = 'failed',
                       error = 'Orphaned: job lost (server restart or persist failure)',
                       error_category = 'orphan_cleanup',
                       completed_at = ?
                   WHERE id = ?""",
                (now_iso, jid),
            )
            reaped += 1
    if reaped:
        logger.info("Reaped %d orphaned screening_jobs row(s) for project %d", reaped, project_id)
    return reaped


# ── Background task ──


async def _run_screening_job(job_id: str, req: ScreeningRequest, openai_api_key: str | None = None):
    """Background task that runs the screening pipeline with timing capture.

    The job dict registered under *job_id* in `_active_jobs` is mutated in
    place (status, progress, stage, results, timings, error fields) and
    persisted with `_save_job` when the job starts and again when it ends.
    Persistence failures are logged and never abort the job.

    Args:
        job_id: Key of the job in `_active_jobs`; must already be registered.
        req: The screening request (papers, criteria, model settings, mock flag).
        openai_api_key: Optional key taken from the request header.

    Raises:
        KeyError: If *job_id* is not present in `_active_jobs`.
    """
    # Set OpenAI key for this task if provided via request header.
    # NOTE(review): os.environ is process-wide, so this key stays set after the
    # job ends and is visible to every other job in this process. Passing the key
    # to the pipeline explicitly would avoid that — confirm whether the pipeline
    # accepts one.
    if openai_api_key:
        os.environ["OPENAI_API_KEY"] = openai_api_key

    job = _active_jobs[job_id]
    job["status"] = "running"
    try:
        _save_job(job)
    except Exception:
        logger.warning(
            "Failed to persist running status for screening job %s",
            job_id,
            exc_info=True,
        )

    start_time = time.monotonic()
    stage_timings: dict[str, float] = {}
    current_stage_start = start_time

    try:
        papers_df = pd.DataFrame(req.papers)

        def progress_cb(current, total, stage):
            nonlocal current_stage_start
            now = time.monotonic()
            # Record timing for the previous stage if stage changed
            prev_stage = job["stage"]
            if prev_stage and prev_stage != stage:
                stage_timings[prev_stage] = round((now - current_stage_start) * 1000)
                current_stage_start = now
            job["progress"] = current / total if total > 0 else 0
            job["stage"] = stage

        if req.mock:
            from crystallise.screening.mock import MockAIService

            mock = MockAIService()
            results_df, clusters = await mock.screen_papers(
                papers_df=papers_df,
                criteria=req.criteria,
                questions=req.questions,
                progress_callback=progress_cb,
                repetitions=req.repetitions,
                clusters_type=req.clusters_type,
                threshold=req.threshold,
            )
        else:
            from crystallise.screening.pipeline import screen_papers

            results_df, clusters = await screen_papers(
                papers_df=papers_df,
                criteria=req.criteria,
                questions=req.questions,
                model_name=req.model,
                repetitions=req.repetitions,
                threshold=req.threshold,
                clusters_type=req.clusters_type,
                progress_callback=progress_cb,
            )

        end_time = time.monotonic()
        # Record final stage timing
        if job["stage"]:
            stage_timings[job["stage"]] = round((end_time - current_stage_start) * 1000)

        # Scrub NaN/Inf here as well as in `_save_job`, so the cached job that
        # polling clients read matches the persisted copy and contains no
        # non-finite floats.
        results = _scrub_nan(results_df.to_dict("records")) if results_df is not None else []
        cluster_list = _scrub_nan(clusters or [])

        job["results"] = results
        job["clusters"] = cluster_list
        job["stage_timings"] = stage_timings
        job["duration_ms"] = round((end_time - start_time) * 1000)
        job["completed_at"] = datetime.now(timezone.utc).isoformat()
        job["progress"] = 1.0
        job["stage"] = "Complete"
        # Flip the status last so a reader never sees 'completed' without results.
        job["status"] = "completed"

    except Exception as e:
        end_time = time.monotonic()
        logger.exception("Screening job %s failed", job_id)

        # Record the terminal state first: if classification below fails, the
        # job must not stay 'running' and keep blocking new jobs for its project.
        job["status"] = "failed"
        job["error"] = str(e)
        job["stage_timings"] = stage_timings  # timings of the stages that did finish
        job["duration_ms"] = round((end_time - start_time) * 1000)
        job["completed_at"] = datetime.now(timezone.utc).isoformat()

        from crystallise.llm.errors import RETRYABLE_CATEGORIES, classify_openai_error

        category = classify_openai_error(e)
        job["error_category"] = category.value
        job["error_retryable"] = category in RETRYABLE_CATEGORIES
    finally:
        try:
            _save_job(job)
        except Exception:
            logger.error("Failed to persist screening job %s to DB", job_id, exc_info=True)


# ── Cost estimation (empirical, mirrors indexer pattern) ──

# Rough per-call token averages consumed by `estimate_screening_cost`.
# Screening uses ~200 tokens per paper for input (title + abstract + criteria context)
# plus ~30 tokens per criterion. Output is ~10 tokens per paper per repetition (score).
# Reasoning stage: ~300 input + ~100 output per paper.
_AVG_INPUT_TOKENS_PER_PAPER = 200  # scoring input, per paper per repetition
_AVG_CRITERIA_TOKENS = 30  # extra scoring input, per criterion
_AVG_SCORE_OUTPUT_TOKENS = 10  # per paper per repetition
_AVG_REASONING_INPUT_TOKENS = 300  # per paper
_AVG_REASONING_OUTPUT_TOKENS = 100  # per paper


@router.post("/estimate", response_model=ScreeningCostEstimateResponse)
def estimate_screening_cost(req: ScreeningCostEstimateRequest):
    """Estimate token usage and cost of a screening run before it starts.

    The scoring stage is counted once per paper per repetition; the reasoning
    stage is counted once per paper. The combined token totals are priced
    with `estimate_cost` for the requested model.
    """
    papers = req.papers_count
    repetitions = req.repetitions

    # Stage 1 (Labelling): every paper is scored `repetitions` times.
    tokens_per_scoring_call = _AVG_INPUT_TOKENS_PER_PAPER + req.criteria_count * _AVG_CRITERIA_TOKENS
    input_tokens = tokens_per_scoring_call * papers * repetitions
    output_tokens = _AVG_SCORE_OUTPUT_TOKENS * papers * repetitions

    # Stage 2 (Reasoning): one call per paper.
    input_tokens = input_tokens + _AVG_REASONING_INPUT_TOKENS * papers
    output_tokens = output_tokens + _AVG_REASONING_OUTPUT_TOKENS * papers

    price = estimate_cost(req.model, input_tokens, output_tokens)

    return ScreeningCostEstimateResponse(
        estimated_input_tokens=input_tokens,
        estimated_output_tokens=output_tokens,
        estimated_cost_usd=round(price, 6),
        model=req.model,
        papers_count=papers,
        repetitions=repetitions,
    )


# ── CRUD endpoints ──


@router.get("/active-job")
def get_active_screening_job(project_id: int):
    """Return the pending/running screening job for a project, or None.

    Only jobs tracked in this process's memory are considered.
    """
    active = _get_active_job_for_project(project_id)
    return None if active is None else _job_to_response(active.get("id", ""), active)


@router.post("/jobs", response_model=ScreeningJobResponse)
async def create_screening_job(req: ScreeningRequest, background_tasks: BackgroundTasks, request: Request):
    """Start an async screening job.

    Steps:
      1. For requests with a project id, reap orphaned DB rows and reject the
         request if a job for that project is still pending/running.
      2. Estimate the cost; enforce ``max_estimated_cost_usd`` when it is set
         and the run is not a mock run.
      3. Persist the new job, register it in `_active_jobs`, and schedule
         `_run_screening_job` as a background task.

    Returns:
        ScreeningJobResponse with the new job id and status "pending".

    Raises:
        HTTPException: 409 if the project already has an active job;
            400 if the estimated cost exceeds the requested ceiling.
    """
    # Prevent duplicate jobs for the same project
    if req.project_id:
        # Reap orphans (status='running' in DB without an in-memory entry)
        # before the duplicate check, so a previous run's stale row doesn't
        # 409-block the user. See `_cleanup_stale_running_jobs`.
        _cleanup_stale_running_jobs(req.project_id)
        existing = _get_active_job_for_project(req.project_id)
        if existing:
            raise HTTPException(
                status_code=409,
                detail=f"A screening job is already {existing['status']} for this project."
                f" Job ID: {existing.get('id', 'unknown')}",
            )

    papers_count = len(req.papers)
    criteria_count = len(req.criteria)

    # The estimate is always computed: it is stored on the job for diagnostics
    # even when no ceiling applies.
    est = estimate_screening_cost(
        ScreeningCostEstimateRequest(
            model=req.model,
            papers_count=papers_count,
            repetitions=req.repetitions,
            criteria_count=criteria_count,
        )
    )
    estimated_cost = est.estimated_cost_usd

    # Cost ceiling check (skipped for mock runs)
    ceiling = req.max_estimated_cost_usd
    if ceiling is not None and not req.mock and estimated_cost > ceiling:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Estimated cost ${estimated_cost:.4f} exceeds "
                f"limit ${ceiling:.4f}. "
                f"Reduce papers ({papers_count}) or repetitions ({req.repetitions})."
            ),
        )

    job_id = str(uuid.uuid4())
    now_utc = datetime.now(timezone.utc).isoformat()

    job = {
        "id": job_id,
        "project_id": req.project_id,
        "status": "pending",
        "progress": 0.0,
        "stage": "",
        "config": {
            "model": req.model,
            "repetitions": req.repetitions,
            "threshold": req.threshold,
            "clusters_type": req.clusters_type,
            "mock": req.mock,
            "papers_count": papers_count,
            "criteria_count": criteria_count,
        },
        "results": None,
        "clusters": None,
        "stage_timings": {},
        "duration_ms": None,
        "estimated_cost_usd": estimated_cost,
        "error": None,
        "error_category": None,
        "error_retryable": None,
        "model_version": req.model,
        "created_at": now_utc,
        "completed_at": None,
    }

    # Persist before registering. If the save raises, the request fails without
    # leaving a 'pending' entry in `_active_jobs` that would never run and would
    # 409-block every later request for the same project.
    _save_job(job)
    _active_jobs[job_id] = job

    # Capture OpenAI key from request header for the background task
    openai_key = request.headers.get("x-openai-api-key")
    background_tasks.add_task(_run_screening_job, job_id, req, openai_key)
    return ScreeningJobResponse(job_id=job_id, status="pending")


@router.get("/jobs", response_model=list[ScreeningJobListItem])
async def list_screening_jobs(limit: int = 50):
    """List recent screening jobs, newest first.

    Args:
        limit: Maximum number of jobs to return. Negative values are treated
            as 0, because a negative SQL LIMIT means "no limit" on some
            backends and is an error on others.

    Returns:
        One ScreeningJobListItem per stored job; empty when the database is
        unavailable.
    """
    rows = _list_jobs(limit=max(limit, 0))

    items = []
    for r in rows:
        # The config key can be present but empty (NULL or ''), in which case
        # r.get("config", {}) would not fall back to the default.
        config = r.get("config") or {}
        items.append(
            ScreeningJobListItem(
                job_id=r["id"],
                status=r["status"],
                progress=r["progress"],
                stage=r.get("stage", ""),
                papers_count=config.get("papers_count", 0),
                model=config.get("model", ""),
                project_id=r.get("project_id"),
                project_name=None,
                duration_ms=r.get("duration_ms"),
                estimated_cost_usd=r.get("estimated_cost_usd"),
                created_at=r.get("created_at"),
                completed_at=r.get("completed_at"),
            )
        )
    return items


@router.get("/jobs/{job_id}", response_model=ScreeningResultResponse)
async def get_screening_job(job_id: str):
    """Return the status and results of one screening job.

    Jobs tracked in memory are served from `_active_jobs`; a job served in a
    terminal state is dropped from that cache. Anything else is loaded from
    the database, and a 404 is raised when no such job exists.
    """
    # In-memory cache first (fast path for active polling).
    cached = _active_jobs.get(job_id)
    if cached is not None:
        is_terminal = cached["status"] in ("completed", "failed")
        if is_terminal:
            # Served once in its final state; later reads go to the DB.
            _active_jobs.pop(job_id, None)
        return _job_to_response(job_id, cached)

    # DB fallback. Pure read: returning state, never writing.
    # Orphaned 'running' rows (server restart, persist failure) are reaped on
    # POST /screening/jobs (`_cleanup_stale_running_jobs`) — never here. The
    # pre-fix auto-fail-on-read flipped rows under a navigating user's
    # session, producing spurious "Job lost due to server restart"
    # attribution. See crystallise-master commit ef47ce3.
    stored = _load_job(job_id)
    if stored is not None:
        return _job_to_response(job_id, stored)

    raise HTTPException(status_code=404, detail="Job not found")


def _job_to_response(job_id: str, job: dict[str, Any]) -> ScreeningResultResponse:
    """Build the API response model for a job dict.

    ``status`` and ``progress`` are required keys of *job*; ``stage`` defaults
    to an empty string and every other field defaults to None when absent.
    """
    nullable_fields = (
        "results",
        "clusters",
        "error",
        "error_category",
        "error_retryable",
        "model_version",
        "config",
        "stage_timings",
        "duration_ms",
        "estimated_cost_usd",
        "created_at",
        "completed_at",
    )
    optional_values = {name: job.get(name) for name in nullable_fields}
    return ScreeningResultResponse(
        job_id=job_id,
        status=job["status"],
        progress=job["progress"],
        stage=job.get("stage", ""),
        **optional_values,
    )
