"""AutoIndexer API endpoints."""

from __future__ import annotations

import asyncio
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import Any

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request

from api.dependencies import get_openai_client
from api.schemas.indexer import (
    CostEstimateRequest,
    CostEstimateResponse,
    GroupTagsRequest,
    GroupTagsResponse,
    IndexerField,
    IndexerJobListItem,
    IndexerJobResponse,
    IndexerJobStatusResponse,
    IndexerRequest,
    IndexerResult,
    RefineFieldsRequest,
    RefineFieldsResponse,
    SuggestFieldsRequest,
    SuggestFieldsResponse,
)
from crystallise.indexer.pipeline import (
    DEFAULT_SYSTEM_PROMPT,
    DEFAULT_USER_PROMPT,
    TOOL_NAME,
    process_record,
)
from crystallise.indexer.schema_builder import create_function_schema
from crystallise.llm.cost import estimate_cost, tally_usage

logger = logging.getLogger(__name__)

# All endpoints below are registered on this router; where it is mounted (and
# under which URL prefix) is decided by whoever includes it — not visible here.
router = APIRouter()

# ── DB helpers ──

# Schema for persisted indexer jobs. Created lazily by _ensure_table().
# config/results/errors/usage hold json.dumps() text written by _save_job;
# error_retryable is stored as 1/0/NULL; timestamps are ISO-8601 strings.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS indexer_jobs (
    id TEXT PRIMARY KEY,
    project_id INTEGER,
    status TEXT NOT NULL DEFAULT 'pending',
    progress REAL NOT NULL DEFAULT 0.0,
    config TEXT NOT NULL DEFAULT '{}',
    results TEXT,
    errors TEXT NOT NULL DEFAULT '[]',
    usage TEXT NOT NULL DEFAULT '{}',
    error TEXT,
    error_category TEXT,
    error_retryable INTEGER,
    model_version TEXT,
    created_at TEXT NOT NULL,
    completed_at TEXT,
    duration_ms INTEGER,
    estimated_cost_usd REAL
)
"""

# Lazy-init flags owned by _ensure_table(): _table_created is set once the table
# is known to exist; _db_available is cleared after the first init failure, and
# from then on the persistence helpers become no-ops for the rest of the process.
_table_created = False
_db_available = True


def _get_db():
    """Return the database backend used to persist indexer jobs.

    The backend module is imported at call time rather than at module import
    (presumably to keep the DB layer out of this module's import graph).
    """
    from crystallise.db.backend import get_backend as backend_factory

    backend = backend_factory()
    return backend


def _ensure_table():
    """Create the ``indexer_jobs`` table on first use.

    Does real work at most once per process. After a success ``_table_created``
    is set. After any failure the error is logged and ``_db_available`` is
    cleared, so later calls (and the persistence helpers) skip the database
    instead of retrying.
    """
    global _table_created, _db_available
    nothing_to_do = _table_created or not _db_available
    if nothing_to_do:
        return
    try:
        backend = _get_db()
        with backend.get_connection() as conn:
            missing = not backend.table_exists(conn, "indexer_jobs")
            if missing:
                backend.executescript(conn, _CREATE_TABLE_SQL)
    except Exception:
        logger.warning(
            "Database init failed for indexer_jobs; running without persistence",
            exc_info=True,
        )
        _db_available = False
    else:
        _table_created = True


def _save_job(job: dict[str, Any]):
    """Write *job* to ``indexer_jobs``: UPDATE by id, INSERT when no row matched.

    No-op when persistence is disabled. ``config``/``errors``/``usage`` are
    stored as JSON text, ``results`` as JSON text or NULL when absent, and
    ``error_retryable`` as 1, 0 or NULL.

    Raises:
        KeyError: if a required key (id, status, progress, created_at) is missing.
    """
    _ensure_table()
    if not _db_available:
        return
    db = _get_db()

    retryable = job.get("error_retryable")
    results = job.get("results")
    # Value order matches the INSERT column list; the UPDATE uses the same
    # values with ``id`` moved to the end for its WHERE clause.
    row_values = (
        job["id"],
        job.get("project_id"),
        job["status"],
        job["progress"],
        json.dumps(job.get("config", {})),
        None if results is None else json.dumps(results),
        json.dumps(job.get("errors", [])),
        json.dumps(job.get("usage", {})),
        job.get("error"),
        job.get("error_category"),
        None if retryable is None else int(bool(retryable)),
        job.get("model_version"),
        job["created_at"],
        job.get("completed_at"),
        job.get("duration_ms"),
        job.get("estimated_cost_usd"),
    )
    update_values = (*row_values[1:], row_values[0])

    with db.get_connection() as conn:
        updated = db.execute(
            conn,
            """UPDATE indexer_jobs SET
                project_id=?, status=?, progress=?, config=?, results=?,
                errors=?, usage=?, error=?, error_category=?,
                error_retryable=?, model_version=?,
                created_at=?, completed_at=?, duration_ms=?, estimated_cost_usd=?
               WHERE id=?""",
            update_values,
        )
        if updated.rowcount == 0:
            db.execute(
                conn,
                """INSERT INTO indexer_jobs
                   (id, project_id, status, progress, config, results, errors, usage,
                    error, error_category, error_retryable, model_version,
                    created_at, completed_at, duration_ms, estimated_cost_usd)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                row_values,
            )


def _load_job(job_id: str) -> dict[str, Any] | None:
    """Fetch one persisted job by id.

    Returns the decoded job dict (see ``_row_to_dict``), or ``None`` when no
    row matches or persistence is disabled.
    """
    _ensure_table()
    if _db_available:
        backend = _get_db()
        with backend.get_connection() as conn:
            found = backend.execute(
                conn, "SELECT * FROM indexer_jobs WHERE id = ?", (job_id,)
            ).fetchone()
            return None if found is None else _row_to_dict(found)
    return None


def _list_jobs(limit: int = 50) -> list[dict[str, Any]]:
    """Return up to *limit* persisted jobs, most recently created first.

    Returns an empty list when persistence is disabled.
    """
    _ensure_table()
    if _db_available:
        backend = _get_db()
        with backend.get_connection() as conn:
            fetched = backend.execute(
                conn,
                "SELECT * FROM indexer_jobs ORDER BY created_at DESC LIMIT ?",
                (limit,),
            ).fetchall()
            return list(map(_row_to_dict, fetched))
    return []


def _row_to_dict(row) -> dict[str, Any]:
    """Convert a DB row to a job dict."""
    r = dict(row)
    for json_field in ("config", "results", "errors", "usage"):
        if r.get(json_field) and isinstance(r[json_field], str):
            r[json_field] = json.loads(r[json_field])
    # Convert error_retryable from int to bool
    if r.get("error_retryable") is not None:
        r["error_retryable"] = bool(r["error_retryable"])
    return r


# ── In-memory cache for active jobs (fast polling) ──

# Job dicts keyed by job id, for jobs created by this process. The background
# runner updates these dicts in place (status, progress, results), so polling can
# read live state without touching the DB; get_indexer_job may evict finished
# entries. This is plain process memory: other worker processes do not see it,
# and it is lost on restart.
_active_jobs: dict[str, dict[str, Any]] = {}


def _get_active_job_for_project(project_id: int) -> dict[str, Any] | None:
    """Return the project's pending/running indexer job tracked by this process, if any.

    A pending/running row that exists only in the database (not in
    ``_active_jobs``) is treated as orphaned, e.g. left behind by a restart. It
    is marked failed and persisted, and ``None`` is returned.

    NOTE(review): with several worker processes, a job that is live in another
    worker would also look orphaned here. Confirm the deployment is single-process.
    """
    # Iterate a snapshot. Sync endpoints (e.g. get_active_indexer_job) run in a
    # worker thread while the event loop may add entries to _active_jobs, and
    # iterating the live dict could raise "dictionary changed size during iteration".
    for job in list(_active_jobs.values()):
        if job.get("project_id") == project_id and job["status"] in ("pending", "running"):
            return job

    _ensure_table()
    if not _db_available:
        return None
    db = _get_db()
    with db.get_connection() as conn:
        cursor = db.execute(
            conn,
            "SELECT * FROM indexer_jobs WHERE project_id = ? AND status IN ('pending', 'running')"
            " ORDER BY created_at DESC LIMIT 1",
            (project_id,),
        )
        row = cursor.fetchone()
        orphan = None if row is None else _row_to_dict(row)
    if orphan is None:
        return None

    # The read connection is released before writing. _save_job opens its own
    # connection, and a write issued while this one is still open can contend
    # for the database lock on some backends (e.g. rollback-journal SQLite).
    orphan["status"] = "failed"
    orphan["error"] = "Job lost due to server restart"
    orphan["completed_at"] = datetime.now(timezone.utc).isoformat()
    _save_job(orphan)
    logger.info("Marked orphaned indexer job %s as failed", orphan["id"])
    return None


def _build_system_prompt(req: IndexerRequest) -> str:
    """Assemble the system prompt shared by every record in a run.

    Starts from ``req.system_prompt`` (falling back to ``DEFAULT_SYSTEM_PROMPT``).
    Appends the field-guidance block, separated by a blank line, when
    ``build_field_injection_block`` returns something non-empty. When
    ``req.project_context`` is set, the result is passed through
    ``build_indexer_system_prompt`` with the project's description and
    research questions.
    """
    from crystallise.prompts.indexer import (
        build_field_injection_block,
        build_indexer_system_prompt,
    )

    prompt = req.system_prompt or DEFAULT_SYSTEM_PROMPT
    guidance = build_field_injection_block([field.model_dump() for field in req.fields])
    if guidance:
        prompt = "\n\n".join((prompt, guidance))

    context = req.project_context
    if not context:
        return prompt
    return build_indexer_system_prompt(
        prompt,
        description=context.description,
        research_questions=context.research_questions,
    )


def _slice_records(req: IndexerRequest) -> list[dict[str, str]]:
    """Return records sliced according to processing mode."""
    records = req.records
    if req.mode == "test":
        return records[: req.test_size]
    elif req.mode == "sample":
        return records[: req.sample_size]
    return records


async def _process_batch(
    req: IndexerRequest,
    records: list[dict[str, str]],
    *,
    on_progress: Any | None = None,
) -> tuple[list[dict], list[str], list[dict]]:
    """Run ``process_record`` over *records* concurrently.

    At most ``req.max_workers`` records are processed at once. The tool schema
    and the prompts are built once and shared by every record.

    Args:
        req: Indexer request (fields, prompts, model, concurrency limit).
        records: Records to process.
        on_progress: Optional callback ``(completed, total)`` invoked after
            each record finishes.

    Returns:
        ``(results, errors, usages)``: the non-None results, truthy error
        messages and truthy usage dicts, each in the same order as *records*
        (not completion order).

    Raises:
        Whatever ``process_record`` raises. Records still pending or in flight
        are cancelled before the exception propagates, so no further
        ``process_record`` calls or progress callbacks happen after a failure.
    """
    fields = [f.model_dump() for f in req.fields]
    fn_schema = create_function_schema(
        TOOL_NAME,
        "Extract structured data with evidence and confidence",
        fields,
    )

    tools = [{"type": "function", "function": fn_schema}]
    system_msg = _build_system_prompt(req)
    user_msg = req.user_prompt or DEFAULT_USER_PROMPT

    semaphore = asyncio.Semaphore(req.max_workers)
    total = len(records)
    completed = 0

    async def _process_one(record: dict) -> tuple[Any, Any, Any]:
        nonlocal completed
        async with semaphore:
            result, usage, error = await process_record(
                record=record,
                tools=tools,
                system_message=system_msg,
                user_message=user_msg,
                model=req.model,
            )
            completed += 1
            if on_progress:
                on_progress(completed, total)
            return result, usage, error

    # Explicit tasks so they can be cancelled on failure. A plain gather()
    # propagates the first exception but leaves the other awaitables running.
    tasks = [asyncio.create_task(_process_one(r)) for r in records]
    try:
        outcomes = await asyncio.gather(*tasks)
    except BaseException:
        for task in tasks:
            task.cancel()
        # Let the cancellations settle, and collect the siblings' own outcomes
        # so they are not reported as "exception was never retrieved".
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    # gather() returns outcomes in input order, regardless of completion order.
    results = [result for result, _, _ in outcomes if result is not None]
    errors = [error for _, _, error in outcomes if error]
    usages = [usage for _, usage, _ in outcomes if usage]
    return results, errors, usages


@router.post("/run", response_model=IndexerResult)
async def run_indexer(req: IndexerRequest):
    """Run the indexer inline and return every extraction in this response.

    Records are limited by ``_slice_records`` (test/sample modes). Usage is
    tallied only when at least one record reported usage; otherwise it is ``{}``.
    """
    extracted, failures, per_record_usage = await _process_batch(req, _slice_records(req))
    totals = tally_usage(req.model, per_record_usage) if per_record_usage else {}
    return IndexerResult(
        results=extracted,
        errors=failures,
        usage=totals,
        model_version=req.model,
    )


# ── Async job pattern ──


async def _run_indexer_job(job_id: str, req: IndexerRequest, openai_api_key: str | None = None):
    """Background task: run the indexer pipeline for one job and record the outcome.

    Updates the job dict in ``_active_jobs`` in place: status, live progress,
    then results/usage or error details. The job is persisted when it starts
    and when it finishes. Persistence failures are logged, never raised.

    Args:
        job_id: Key of a job already registered in ``_active_jobs``.
        req: The original indexer request.
        openai_api_key: Optional key taken from the ``x-openai-api-key`` header.
    """
    # NOTE(review): writing the key into os.environ changes process-wide state.
    # Concurrent jobs with different keys race on the same variable, and the key
    # is never removed, so later work in this process presumably inherits it.
    # Passing the key to the client used by process_record would avoid this, but
    # that needs an API change there.
    if openai_api_key:
        import os

        os.environ["OPENAI_API_KEY"] = openai_api_key

    job = _active_jobs[job_id]
    job["status"] = "running"
    try:
        _save_job(job)
    except Exception:
        logger.warning(
            "Failed to persist running status for indexer job %s", job_id, exc_info=True
        )

    start_time = time.monotonic()

    try:
        records = _slice_records(req)

        def progress_cb(current, total):
            job["progress"] = current / total if total > 0 else 0

        results, errors, usages = await _process_batch(
            req,
            records,
            on_progress=progress_cb,
        )
        usage = tally_usage(req.model, usages) if usages else {}

        job["progress"] = 1.0
        job["results"] = results
        job["errors"] = errors
        job["usage"] = usage
        job["duration_ms"] = round((time.monotonic() - start_time) * 1000)
        job["completed_at"] = datetime.now(timezone.utc).isoformat()
        # Status is assigned last. get_indexer_job (a sync endpoint, so it runs
        # in a worker thread) evicts the cache entry as soon as it sees a
        # terminal status, so every other field must already be filled in.
        job["status"] = "completed_with_errors" if errors else "completed"

    except Exception as e:
        job["error"] = str(e)
        job["duration_ms"] = round((time.monotonic() - start_time) * 1000)
        job["completed_at"] = datetime.now(timezone.utc).isoformat()
        logger.warning("Indexer job %s failed", job_id, exc_info=True)
        # Classification is best-effort. If it fails, the job must still end up
        # "failed" rather than stuck in "running", which would also block new
        # jobs for the project via _get_active_job_for_project.
        try:
            from crystallise.llm.errors import RETRYABLE_CATEGORIES, classify_openai_error

            category = classify_openai_error(e)
            job["error_category"] = category.value
            job["error_retryable"] = category in RETRYABLE_CATEGORIES
        except Exception:
            logger.warning(
                "Could not classify error for indexer job %s", job_id, exc_info=True
            )
        job["status"] = "failed"
    finally:
        try:
            _save_job(job)
        except Exception:
            logger.error("Failed to persist indexer job %s to DB", job_id, exc_info=True)


@router.get("/active-job")
def get_active_indexer_job(project_id: int):
    """Return the project's pending/running indexer job, or ``None`` if there is none.

    Side effect: ``_get_active_job_for_project`` marks a pending/running row
    that exists only in the database (not in this process's cache) as failed.
    """
    active = _get_active_job_for_project(project_id)
    return None if active is None else _job_to_response(active.get("id", ""), active)


@router.post("/jobs", response_model=IndexerJobResponse)
async def create_indexer_job(
    req: IndexerRequest,
    background_tasks: BackgroundTasks,
    request: Request,
):
    """Start an async indexer job and return its id immediately.

    The job is registered in ``_active_jobs`` and run as a FastAPI background
    task (``_run_indexer_job``). Poll ``GET /jobs/{job_id}`` for progress and
    results. An optional ``x-openai-api-key`` header is forwarded to the task.

    Raises:
        HTTPException: 409 if ``req.project_id`` is set and that project
            already has a pending or running job.
    """
    # Prevent duplicate jobs for the same project
    if req.project_id:
        existing = _get_active_job_for_project(req.project_id)
        if existing:
            raise HTTPException(
                status_code=409,
                detail=f"An indexer job is already {existing['status']} for this project."
                f" Job ID: {existing.get('id', 'unknown')}",
            )

    # Estimated cost for diagnostics, using the same per-record/per-field token
    # averages as estimate_indexer_cost, applied to the records this mode will process.
    n_fields = len(req.fields)
    input_per_record = _AVG_INPUT_TOKENS_PER_RECORD + (n_fields * _AVG_INPUT_TOKENS_PER_FIELD)
    output_per_record = n_fields * _AVG_OUTPUT_TOKENS_PER_FIELD
    records = _slice_records(req)
    total_input = input_per_record * len(records)
    total_output = output_per_record * len(records)
    estimated_cost = round(estimate_cost(req.model, total_input, total_output), 6)

    job_id = str(uuid.uuid4())
    now_utc = datetime.now(timezone.utc).isoformat()

    job = {
        "id": job_id,
        "project_id": req.project_id,
        "status": "pending",
        "progress": 0.0,
        "config": {
            "model": req.model,
            "record_count": len(records),
            "fields_count": n_fields,
            "mode": req.mode,
            "mock": False,
        },
        "results": None,
        "errors": [],
        "usage": {},
        "error": None,
        "model_version": req.model,
        "estimated_cost_usd": estimated_cost,
        "duration_ms": None,
        "created_at": now_utc,
        "completed_at": None,
    }

    _active_jobs[job_id] = job
    try:
        _save_job(job)
    except Exception:
        # Persistence is best-effort, as in _run_indexer_job. Letting this raise
        # would return 500 while leaving a "pending" job in _active_jobs that
        # never runs, and every later request for this project would then get 409.
        logger.warning("Failed to persist new indexer job %s", job_id, exc_info=True)

    openai_key = request.headers.get("x-openai-api-key")
    background_tasks.add_task(_run_indexer_job, job_id, req, openai_key)
    return IndexerJobResponse(job_id=job_id, status="pending")


@router.get("/jobs/{job_id}", response_model=IndexerJobStatusResponse)
def get_indexer_job(job_id: str):
    """Return status, progress and (once finished) results for one job.

    Live state comes from the in-memory cache when this process tracks the
    job; otherwise the persisted row is used.

    Raises:
        HTTPException: 404 if the job is in neither the cache nor the database.
    """
    # A single .get() instead of ``in`` followed by ``[]``. This sync endpoint
    # runs in a worker thread, so a concurrent poll could evict the entry
    # between the two operations and cause a KeyError.
    job = _active_jobs.get(job_id)
    if job is not None:
        # Evict finished jobs only when a database is available. Without one,
        # the cache is the only copy, and evicting would make the next poll 404.
        if job["status"] in ("completed", "completed_with_errors", "failed") and _db_available:
            _active_jobs.pop(job_id, None)
        return _job_to_response(job_id, job)

    # Fall back to DB
    job = _load_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return _job_to_response(job_id, job)


def _job_to_response(job_id: str, job: dict[str, Any]) -> IndexerJobStatusResponse:
    """Build the API status model for a job dict (in-memory or decoded DB row).

    ``status`` and ``progress`` are required keys. The job's ``results`` are
    exposed as ``partial_results``, and every other optional key defaults to
    ``None`` when absent.
    """
    optional_keys = (
        "errors",
        "usage",
        "error",
        "config",
        "model_version",
        "duration_ms",
        "estimated_cost_usd",
        "created_at",
        "completed_at",
        "error_category",
        "error_retryable",
    )
    passthrough = {key: job.get(key) for key in optional_keys}
    return IndexerJobStatusResponse(
        job_id=job_id,
        status=job["status"],
        progress=job["progress"],
        partial_results=job.get("results"),
        **passthrough,
    )


@router.get("/jobs", response_model=list[IndexerJobListItem])
def list_indexer_jobs(limit: int = 50):
    """List up to *limit* recent indexer jobs, newest first.

    Rows come from the database. For a job this process still tracks in
    ``_active_jobs``, the in-memory dict is used instead: progress is only
    written to the database when a job is created, starts and finishes, so the
    persisted row would show stale progress while the job runs.
    """
    items: list[IndexerJobListItem] = []
    for row in _list_jobs(limit=limit):
        # Per-key .get() rather than iterating the shared dict, so this worker
        # thread cannot trip over concurrent inserts or evictions.
        job = _active_jobs.get(row["id"], row)
        config = job.get("config") or {}
        items.append(
            IndexerJobListItem(
                job_id=job["id"],
                status=job["status"],
                progress=job["progress"],
                model=config.get("model", ""),
                record_count=config.get("record_count", 0),
                duration_ms=job.get("duration_ms"),
                estimated_cost_usd=job.get("estimated_cost_usd"),
                created_at=job.get("created_at"),
                completed_at=job.get("completed_at"),
            )
        )
    return items


# ── Cost estimation ──

# Empirical token averages used for cost projections. Only the first constant is
# a per-record figure; the other two scale with the number of extraction fields.
# These names are read at call time by estimate_indexer_cost and
# create_indexer_job, so defining them after create_indexer_job is fine.
_AVG_INPUT_TOKENS_PER_RECORD = 250  # title + abstract
_AVG_INPUT_TOKENS_PER_FIELD = 30  # field schema overhead
_AVG_OUTPUT_TOKENS_PER_FIELD = 25  # extraction output per field


@router.post("/estimate", response_model=CostEstimateResponse)
def estimate_indexer_cost(req: CostEstimateRequest):
    """Project token usage and USD cost for an indexer run.

    Per record: input = record content + per-field schema overhead, and
    output = per-field extraction output, using the module's empirical
    averages. Totals scale linearly with ``req.record_count``. The cost comes
    from ``estimate_cost`` and is rounded to 6 decimals.
    """
    field_count = len(req.fields)
    per_record_in = _AVG_INPUT_TOKENS_PER_RECORD + field_count * _AVG_INPUT_TOKENS_PER_FIELD
    per_record_out = field_count * _AVG_OUTPUT_TOKENS_PER_FIELD

    input_tokens = per_record_in * req.record_count
    output_tokens = per_record_out * req.record_count
    usd = estimate_cost(req.model, input_tokens, output_tokens)

    return CostEstimateResponse(
        estimated_input_tokens=input_tokens,
        estimated_output_tokens=output_tokens,
        estimated_cost_usd=round(usd, 6),
    )


# ── AI Field Refinement ──


@router.post("/refine-fields", response_model=RefineFieldsResponse)
def refine_fields(req: RefineFieldsRequest, client=Depends(get_openai_client)):
    """Ask the model to review field definitions and suggest improvements.

    Delegates to ``crystallise.indexer.refinement.refine_fields`` with the
    client injected by ``get_openai_client``.
    """
    from crystallise.indexer.refinement import refine_fields as run_refinement

    suggestions = run_refinement(client=client, req=req)
    return suggestions


# ── Tag Grouping ──


@router.post("/group-tags", response_model=GroupTagsResponse)
def group_tags(req: GroupTagsRequest, client=Depends(get_openai_client)):
    """Group extracted tag values with model assistance.

    Delegates to ``crystallise.indexer.grouping.group_tags`` with the client
    injected by ``get_openai_client``.
    """
    from crystallise.indexer.grouping import group_tags as run_grouping

    grouped = run_grouping(client=client, req=req)
    return grouped


# ── Field Suggestion ──


def _flatten_examples(examples: list[Any]) -> list[str]:
    """Flatten one level of nested example lists into a flat list of strings.

    Suggested examples may arrive as nested lists. Each scalar item, and each
    member of a nested list, is stringified. Placeholder values (None, empty
    strings, empty containers) are dropped. Numbers are kept even when falsy,
    because ``0`` is a legitimate example for a numeric field.
    """

    def keep(value: Any) -> bool:
        return isinstance(value, (int, float)) or bool(value)

    flat: list[str] = []
    for item in examples:
        members = item if isinstance(item, list) else [item]
        flat.extend(str(v) for v in members if keep(v))
    return flat


def _mock_suggest_fields_response(existing_fields) -> SuggestFieldsResponse:
    """Canned ``/suggest-fields`` response for ``mock`` requests (no model call).

    Echoes *existing_fields* with placeholder descriptions when any are given;
    otherwise returns a fixed set of five common systematic-review fields.
    """
    if existing_fields:
        return SuggestFieldsResponse(
            fields=[
                IndexerField(
                    name=name,
                    description=f"Mock description for '{name}' — what this field captures in the context of the systematic review.",
                    data_type_primary="string",
                    examples=["example1", "example2"],
                )
                for name in existing_fields
            ]
        )
    return SuggestFieldsResponse(
        fields=[
            IndexerField(
                name="study_design",
                description="Type of study",
                data_type_primary="string",
                examples=["RCT", "cohort", "case-control"],
            ),
            IndexerField(
                name="sample_size",
                description="Number of participants",
                data_type_primary="number",
                examples=["100", "500"],
            ),
            IndexerField(
                name="country",
                description="Country where study conducted",
                data_type_primary="string",
                examples=["USA", "UK"],
            ),
            IndexerField(
                name="interventions",
                description="Interventions studied",
                data_type_primary="array-string",
                examples=["drug A", "placebo"],
            ),
            IndexerField(
                name="outcomes",
                description="Primary outcomes measured",
                data_type_primary="array-string",
                examples=["mortality", "survival"],
            ),
        ]
    )


@router.post("/suggest-fields", response_model=SuggestFieldsResponse)
async def suggest_fields(req: SuggestFieldsRequest):
    """AI-powered field suggestion based on project context.

    With ``req.mock`` set, returns canned fields without calling the model.
    Otherwise calls ``crystallise.indexer.field_suggestion.suggest_fields`` and
    normalises its output: nested example lists are flattened to strings, and
    warnings that are not dicts or lack a ``field`` are skipped.

    Raises:
        Whatever ``api.utils.classify_and_raise`` raises for a failure during
        suggestion or response construction (presumably an HTTPException).
    """
    if req.mock:
        return _mock_suggest_fields_response(req.existing_fields)

    from crystallise.indexer.field_suggestion import suggest_fields as _suggest

    try:
        raw_fields, raw_warnings = await _suggest(
            project_description=req.project_context.description if req.project_context else "",
            research_questions=req.project_context.research_questions if req.project_context else [],
            pico=req.pico,
            sample_records=req.sample_records,
            existing_fields=req.existing_fields,
            model=req.model,
        )
        for f in raw_fields:
            if "examples" in f and isinstance(f["examples"], list):
                f["examples"] = _flatten_examples(f["examples"])
        fields = [IndexerField(**f) for f in raw_fields]
        from api.schemas.indexer import ExtractionWarning

        warnings = [
            ExtractionWarning(**w)
            for w in raw_warnings
            if isinstance(w, dict) and w.get("field")
        ]
        return SuggestFieldsResponse(fields=fields, warnings=warnings)
    except Exception as e:
        from api.utils import classify_and_raise

        classify_and_raise(e)
        # Defensive: if classify_and_raise ever returns, re-raise instead of
        # falling through to an implicit ``None`` response.
        raise
