#!/usr/bin/env python3
"""End-to-end smoke test for every AI endpoint of netready-mvp.

Runs a two-pass execution against the live FastAPI service:
  1. Mock pass — every endpoint that supports `mock: true` (full corpus, free).
  2. Live pass — all endpoints, downsampled (default 20 papers / 20 records).

Outputs:
  - stdout pass/fail table with elapsed ms per endpoint
  - scripts/output/<ISO timestamp>/<endpoint-slug>.json  (full responses)
  - exit 0 on success, non-zero if any endpoint fails

Usage:
  python scripts/smoke_ai_endpoints.py                          # two-pass (live needs OPENAI_API_KEY)
  python scripts/smoke_ai_endpoints.py --mock-only              # CI-friendly, no LLM calls
  python scripts/smoke_ai_endpoints.py --only screening,criteria
  python scripts/smoke_ai_endpoints.py --live-sample 5 --base-url http://localhost:8005
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable

import httpx

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from scripts._fixtures import Fixtures, load_fixtures, sample as sample_fixtures  # noqa: E402


# Upper bound, in seconds, on how long a screening job is polled before giving up.
SCREENING_POLL_TIMEOUT_S = 240
# Upper bound, in seconds, on how long an indexer job is polled before giving up.
INDEXER_POLL_TIMEOUT_S = 240
# Pause, in seconds, between two consecutive job-status polls.
POLL_INTERVAL_S = 1.0


# ── result tracking ───────────────────────────────────────────────────────────


@dataclass
class EndpointResult:
    """Outcome of exercising a single endpoint during one smoke-test pass."""

    # Display label of the endpoint, e.g. "POST /v1/screening/estimate".
    name: str
    # True when the endpoint call succeeded (trailing underscore avoids the `pass` keyword).
    pass_: bool
    # Elapsed time in milliseconds, as measured by whoever builds the result.
    elapsed_ms: int
    # Short error text for failed calls; left empty for successes.
    detail: str = ""
    # Value returned by the endpoint runner; left as None when nothing was returned.
    response: Any = None


@dataclass
class RunReport:
    """Collects the endpoint results of one pass, split into passes and failures."""

    pass_: list[EndpointResult] = field(default_factory=list)
    fail: list[EndpointResult] = field(default_factory=list)

    def add(self, r: EndpointResult) -> None:
        """File *r* under `pass_` or `fail` according to its own `pass_` flag."""
        if r.pass_:
            self.pass_.append(r)
        else:
            self.fail.append(r)


# ── HTTP helpers ──────────────────────────────────────────────────────────────


class ApiClient:
    """Small wrapper around `httpx.Client` that also archives request/response pairs.

    Every request carries the `X-API-Key` header and, when an OpenAI key is supplied,
    the `X-OpenAI-API-Key` header. When `output_dir` is set, POST exchanges are written
    to `<output_dir>/<slug>.json`. The instance can be used as a context manager, which
    closes the underlying HTTP client on exit.
    """

    def __init__(self, base_url: str, api_key: str, openai_key: str | None, *, output_dir: Path | None):
        headers = {"X-API-Key": api_key, "Content-Type": "application/json"}
        if openai_key:
            headers["X-OpenAI-API-Key"] = openai_key
        self.client = httpx.Client(base_url=base_url, headers=headers, timeout=120.0)
        self.output_dir = output_dir

    def __enter__(self) -> "ApiClient":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        self.client.close()

    def _dump(self, slug: str, response: Any) -> None:
        """Write *response* as indented JSON to `<output_dir>/<slug>.json` (no-op without output_dir)."""
        if self.output_dir is None:
            return
        path = self.output_dir / f"{slug}.json"
        # default=str keeps the dump best-effort for values json cannot encode natively.
        path.write_text(json.dumps(response, indent=2, default=str), encoding="utf-8")

    def post(self, path: str, payload: dict, slug: str) -> Any:
        """POST *payload* as JSON, archive the exchange under *slug*, and return the decoded body.

        The returned body is the parsed JSON, or the raw text when the response is not JSON.

        Raises:
            httpx.HTTPStatusError: if the response status is 400 or above.
        """
        resp = self.client.post(path, json=payload)
        body = self._safe_json(resp)
        # Dump before raising so failed exchanges are archived too.
        self._dump(slug, {"status_code": resp.status_code, "request": _truncate(payload), "response": body})
        self._raise_for_error(resp, body)
        return body

    def get(self, path: str) -> Any:
        """GET *path* and return the decoded body (parsed JSON, or raw text when not JSON).

        Raises:
            httpx.HTTPStatusError: if the response status is 400 or above.
        """
        resp = self.client.get(path)
        body = self._safe_json(resp)
        self._raise_for_error(resp, body)
        return body

    @staticmethod
    def _raise_for_error(resp: httpx.Response, body: Any) -> None:
        """Raise `httpx.HTTPStatusError` with a short body excerpt when status is >= 400."""
        if resp.status_code < 400:
            return
        detail = body if isinstance(body, str) else json.dumps(body)
        # Cap the excerpt for text bodies as well (e.g. a full HTML error page).
        raise httpx.HTTPStatusError(
            f"{resp.status_code}: {detail[:300]}",
            request=resp.request,
            response=resp,
        )

    @staticmethod
    def _safe_json(resp: httpx.Response) -> Any:
        """Return the parsed JSON body, falling back to the raw text when it is not valid JSON."""
        try:
            return resp.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            return resp.text


def _truncate(payload: dict, list_cap: int = 5) -> dict:
    """Return a shallow copy of *payload* whose bulky list fields are shortened for logging.

    Only the "papers", "records", "values" and "criteria" keys are considered; a list longer
    than *list_cap* keeps its first *list_cap* items followed by a "…N more" marker.
    The input mapping is left untouched.
    """
    trimmed = dict(payload)
    for name in ("papers", "records", "values", "criteria"):
        items = trimmed.get(name)
        if not isinstance(items, list):
            continue
        overflow = len(items) - list_cap
        if overflow > 0:
            trimmed[name] = [*items[:list_cap], f"…{overflow} more"]
    return trimmed


def _poll_job(
    client: ApiClient,
    status_path: str,
    slug: str,
    timeout_s: int,
) -> dict:
    """Poll a job status endpoint until it reaches a terminal state or the timeout expires.

    A progress line is printed whenever the (status, stage, progress) triple changes.
    The last status body seen is dumped under *slug* in both the success and timeout paths.

    Args:
        client: API client used for the GET requests and for dumping the final body.
        status_path: Path of the job status endpoint.
        slug: File slug under which the last status body is dumped.
        timeout_s: Maximum number of seconds to keep polling.

    Returns:
        The status body whose "status" is "completed", "failed" or "completed_with_errors".

    Raises:
        RuntimeError: if the status endpoint answers with something other than a JSON object.
        TimeoutError: if no terminal status is observed within *timeout_s* seconds.
    """
    started = time.monotonic()
    deadline = started + timeout_s
    last: dict = {}
    last_signature: tuple = ()
    while time.monotonic() < deadline:
        body = client.get(status_path)
        if not isinstance(body, dict):
            # The client falls back to raw text for non-JSON bodies; fail with a readable message
            # instead of an AttributeError on `.get`.
            raise RuntimeError(f"{status_path} returned a non-object body: {str(body)[:200]!r}")
        last = body
        status = last.get("status")
        sig = (status, last.get("stage"), round(float(last.get("progress") or 0), 2))
        if sig != last_signature:
            elapsed = int(time.monotonic() - started)
            # round(), not int(): 0.29 * 100 is 28.999999999999996 and would print as 28%.
            pct = round(sig[2] * 100)
            stage = sig[1] or ""
            print(f"      [{elapsed:>3}s] status={status} progress={pct}% stage={stage!r}")
            last_signature = sig
        if status in ("completed", "failed", "completed_with_errors"):
            client._dump(slug, last)
            return last
        time.sleep(POLL_INTERVAL_S)
    client._dump(slug, last)
    raise TimeoutError(f"Job did not finish within {timeout_s}s; last status: {last.get('status')!r}")


# ── per-endpoint runners ──────────────────────────────────────────────────────


def _time(fn: Callable[[], dict]) -> tuple[dict, int]:
    """Call *fn* and return `(result, elapsed_ms)`, timed with the monotonic clock."""
    started = time.monotonic()
    result = fn()
    elapsed_ms = int((time.monotonic() - started) * 1000)
    return result, elapsed_ms


def _record(name: str, fn: Callable[[], dict], report: RunReport) -> None:
    """Run one endpoint check, print a ✓/✗ line, and add the outcome to *report*.

    Any exception raised by *fn* is treated as a failure of that endpoint rather than
    of the whole run, so the remaining endpoints still get exercised.

    Args:
        name: Display label of the endpoint.
        fn: Zero-argument callable performing the request(s) and returning the response body.
        report: Report that receives the resulting `EndpointResult`.
    """
    t0 = time.monotonic()
    try:
        body = fn()
    except Exception as e:  # deliberate per-endpoint boundary: record the failure and carry on
        # Keep the real elapsed time: a call that failed after a long timeout must not look instant.
        ms = int((time.monotonic() - t0) * 1000)
        # Prefix the exception type so message-less or terse errors (e.g. KeyError) stay readable.
        reason = f"{type(e).__name__}: {e}"
        report.add(EndpointResult(name, False, ms, detail=reason[:200]))
        print(f"  ✗ {name:<40} FAILED — {reason[:160]}")
    else:
        ms = int((time.monotonic() - t0) * 1000)
        report.add(EndpointResult(name, True, ms, response=body))
        print(f"  ✓ {name:<40} {ms:>6} ms")


# ── endpoint families ─────────────────────────────────────────────────────────


def run_screening(client: ApiClient, fx: Fixtures, *, mock: bool, report: RunReport) -> None:
    """Exercise the screening endpoints: the cost estimate, then a job that is polled to completion.

    Both checks are recorded on *report*. The job check fails unless the job ends in
    status "completed".
    """
    label = "mock" if mock else "live"
    print(f"\n[{label}] screening")

    def _estimate() -> dict:
        # Payload is built inside the callable so fixture problems count as endpoint failures.
        payload = {
            "model": "gpt-5-nano",
            "papers_count": len(fx.papers),
            "repetitions": 3,
            "criteria_count": len(fx.criteria),
        }
        return client.post("/v1/screening/estimate", payload, slug=f"{label}_screening_estimate")

    _record("POST /v1/screening/estimate", _estimate, report)

    def _create_and_wait() -> dict:
        payload = {
            "papers": fx.papers,
            "criteria": fx.criteria,
            "questions": fx.questions,
            "model": "gpt-5-nano",
            "repetitions": 2,  # keep cost minimal for live
            "threshold": 1.0,
            "mock": mock,
        }
        created = client.post("/v1/screening/jobs", payload, slug=f"{label}_screening_jobs_create")
        job_id = created["job_id"]
        outcome = _poll_job(
            client,
            f"/v1/screening/jobs/{job_id}",
            slug=f"{label}_screening_jobs_final",
            timeout_s=SCREENING_POLL_TIMEOUT_S,
        )
        if outcome.get("status") == "completed":
            return outcome
        raise RuntimeError(f"job ended in status {outcome.get('status')!r}: {outcome.get('error')!r}")

    _record("POST /v1/screening/jobs (+ poll)", _create_and_wait, report)


def run_criteria(client: ApiClient, fx: Fixtures, *, mock: bool, report: RunReport) -> None:
    """Exercise the six criteria endpoints, recording one result per endpoint on *report*.

    Every request carries the `mock` flag. Payloads are built lazily, inside the recorded
    call, so a fixture problem is reported as a failure of that endpoint only.
    """
    label = "mock" if mock else "live"
    print(f"\n[{label}] criteria")
    common = {"mock": mock}

    # (endpoint name under /v1/criteria/, lazy payload builder), in execution order.
    cases = [
        (
            "picos",
            lambda: {"project_description": fx.project_description, "research_questions": fx.questions, **common},
        ),
        (
            "analyze-question",
            lambda: {"research_question": fx.questions[0], **common},
        ),
        (
            "generate",
            lambda: {
                "project_description": fx.project_description,
                "research_questions": fx.questions,
                "criterion_type": "exclude",
                **common,
            },
        ),
        (
            "refine-context",
            lambda: {"description": fx.project_description, "research_questions": fx.questions, **common},
        ),
        (
            "refine",
            lambda: {
                "current_criteria": fx.criteria,
                "conflicts": [],
                "project_description": fx.project_description,
                **common,
            },
        ),
        (
            "consolidate",
            lambda: {
                "criteria": fx.criteria,
                "project_description": fx.project_description,
                "research_questions": fx.questions,
                **common,
            },
        ),
    ]

    for endpoint, build_payload in cases:
        path = f"/v1/criteria/{endpoint}"
        # Dump slugs use underscores where the URL uses dashes.
        slug = f"{label}_criteria_{endpoint.replace('-', '_')}"

        def _call(path: str = path, slug: str = slug, build_payload=build_payload) -> dict:
            return client.post(path, build_payload(), slug=slug)

        _record(f"POST {path}", _call, report)


def run_indexer(client: ApiClient, fx: Fixtures, *, mock: bool, report: RunReport) -> None:
    """Exercise the indexer endpoints, recording one result per endpoint on *report*.

    In the mock pass only suggest-fields is called; the other endpoints are skipped and
    left to the live pass. group-tags is skipped when the fixtures have no country values.
    """
    label = "mock" if mock else "live"
    print(f"\n[{label}] indexer")

    project_ctx = {"description": fx.project_description, "research_questions": fx.questions}

    def _suggest_fields() -> dict:
        payload = {
            "project_context": project_ctx,
            "sample_records": fx.indexer_records[:3],
            "existing_fields": [],
            "mock": mock,
        }
        return client.post("/v1/indexer/suggest-fields", payload, slug=f"{label}_indexer_suggest_fields")

    _record("POST /v1/indexer/suggest-fields", _suggest_fields, report)

    if mock:
        # The remaining indexer endpoints don't accept mock=true; defer to live pass.
        print("  (skipping /run, /jobs, /refine-fields, /group-tags, /estimate — no mock support)")
        return

    def _test_mode_payload() -> dict:
        # Shared by /run and /jobs, which take the same request body.
        return {
            "model": "gpt-5-mini",
            "records": fx.indexer_records,
            "fields": fx.indexer_fields,
            "project_context": project_ctx,
            "mode": "test",
            "test_size": min(3, len(fx.indexer_records)),
            "max_workers": 2,
        }

    def _estimate() -> dict:
        payload = {"model": "gpt-5-mini", "fields": fx.indexer_fields, "record_count": len(fx.indexer_records)}
        return client.post("/v1/indexer/estimate", payload, slug="live_indexer_estimate")

    def _run() -> dict:
        return client.post("/v1/indexer/run", _test_mode_payload(), slug="live_indexer_run")

    def _create_and_wait() -> dict:
        created = client.post("/v1/indexer/jobs", _test_mode_payload(), slug="live_indexer_jobs_create")
        outcome = _poll_job(
            client,
            f"/v1/indexer/jobs/{created['job_id']}",
            slug="live_indexer_jobs_final",
            timeout_s=INDEXER_POLL_TIMEOUT_S,
        )
        # Unlike screening, a job that finished with partial errors still counts as a pass here.
        if outcome.get("status") in ("completed", "completed_with_errors"):
            return outcome
        raise RuntimeError(f"job ended in status {outcome.get('status')!r}: {outcome.get('error')!r}")

    def _refine_fields() -> dict:
        payload = {
            "fields": fx.indexer_fields,
            "project_context": project_ctx,
            "sample_records": fx.indexer_records[:3],
        }
        return client.post("/v1/indexer/refine-fields", payload, slug="live_indexer_refine_fields")

    _record("POST /v1/indexer/estimate", _estimate, report)
    _record("POST /v1/indexer/run", _run, report)
    _record("POST /v1/indexer/jobs (+ poll)", _create_and_wait, report)
    _record("POST /v1/indexer/refine-fields", _refine_fields, report)

    if not fx.country_values:
        print("  (skipping /group-tags — no country values in fixtures)")
        return

    def _group_tags() -> dict:
        payload = {"field_name": "country", "values": fx.country_values, "project_context": project_ctx}
        return client.post("/v1/indexer/group-tags", payload, slug="live_indexer_group_tags")

    _record("POST /v1/indexer/group-tags", _group_tags, report)


# ── orchestration ─────────────────────────────────────────────────────────────


# Maps each family name accepted by `--only` to the runner that exercises that family's
# endpoints. Each runner is called as runner(client, fixtures, mock=..., report=...).
FAMILIES = {
    "screening": run_screening,
    "criteria": run_criteria,
    "indexer": run_indexer,
}


def _positive_int(raw: str) -> int:
    """argparse `type=` callback that accepts only integers greater than or equal to 1."""
    try:
        value = int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {raw!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {raw!r}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the smoke test.

    Args:
        argv: Argument list to parse; None makes argparse read `sys.argv[1:]`.

    Returns:
        The populated namespace. Key defaults come from the environment:
        `--api-key` from CRYSTALLISE_API_KEY (else "dev-key"), `--openai-key` from
        OPENAI_API_KEY or CRYSTALLISE_OPENAI_API_KEY (else None).
    """
    # __doc__ is None when docstrings are stripped (python -OO); don't crash on .split then.
    summary = (__doc__ or "").split("\n")[0]
    p = argparse.ArgumentParser(description=summary or None)
    p.add_argument("--base-url", default="http://localhost:8005")
    p.add_argument("--api-key", default=os.environ.get("CRYSTALLISE_API_KEY", "dev-key"))
    p.add_argument("--openai-key", default=os.environ.get("OPENAI_API_KEY") or os.environ.get("CRYSTALLISE_OPENAI_API_KEY"))
    p.add_argument("--data-root", default="data")
    # Zero or negative sample sizes are rejected up front instead of being passed to the sampler.
    p.add_argument(
        "--live-sample",
        type=_positive_int,
        default=20,
        help="Papers/records sent during the live pass (positive integer)",
    )
    p.add_argument("--only", default="screening,criteria,indexer", help="Comma-separated subset of families")
    mode = p.add_mutually_exclusive_group()
    mode.add_argument("--mock-only", action="store_true")
    mode.add_argument("--live-only", action="store_true")
    p.add_argument("--continue-on-fail", action="store_true", help="Run live pass even if mock pass had failures")
    p.add_argument("--no-output", action="store_true", help="Skip writing JSON dumps to scripts/output/")
    p.add_argument("--output-root", default="scripts/output")
    return p.parse_args(argv)


def _dump_inputs(output_dir: Path, fx: Fixtures, args: argparse.Namespace) -> None:
    """Snapshot every test input to `_inputs.json` so the user can inspect what was sent.

    Lists like papers/records are saved in full so nothing is hidden — these become the
    audit trail for the run. The indexer fields are additionally written on their own to
    `_indexer_fields.json` for quick inspection.

    Args:
        output_dir: Existing directory that receives the two JSON files.
        fx: Fixtures whose contents and counts are recorded.
        args: Parsed CLI options; the run-shaping ones are recorded under "run".
    """
    snapshot = {
        "run": {
            "base_url": args.base_url,
            "live_sample": args.live_sample,
            "only": args.only,
            "mock_only": args.mock_only,
            "live_only": args.live_only,
            "data_root": args.data_root,
        },
        "project": {
            "name": fx.project_name,
            "description": fx.project_description,
        },
        "counts": {
            "papers": len(fx.papers),
            "criteria": len(fx.criteria),
            "questions": len(fx.questions),
            "indexer_records": len(fx.indexer_records),
            "indexer_fields": len(fx.indexer_fields),
            "country_values": len(fx.country_values),
        },
        "papers": fx.papers,
        "criteria": fx.criteria,
        "questions": fx.questions,
        "indexer_records": fx.indexer_records,
        "indexer_fields": fx.indexer_fields,
        "country_values": fx.country_values,
    }
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must be pinned:
    # the locale default (e.g. cp1252) could raise UnicodeEncodeError or yield a non-UTF-8 file.
    (output_dir / "_inputs.json").write_text(
        json.dumps(snapshot, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    # Also write indexer_fields standalone for quick inspection — it's the spec the user asked about.
    (output_dir / "_indexer_fields.json").write_text(
        json.dumps(fx.indexer_fields, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )


def _print_summary(label: str, report: RunReport) -> None:
    """Print the passed/total line for one pass, followed by one line per failure (if any)."""
    passed = len(report.pass_)
    failed = len(report.fail)
    print(f"\n── {label} summary: {passed}/{passed + failed} passed ──")
    if not report.fail:
        return
    print("FAILURES:")
    for entry in report.fail:
        print(f"  - {entry.name}: {entry.detail}")


def main(argv: list[str] | None = None) -> int:
    """Run the mock and/or live smoke-test passes and return the process exit code.

    Args:
        argv: Command-line arguments; None makes argparse read `sys.argv[1:]`.

    Returns:
        0 if every exercised endpoint passed,
        1 if any endpoint failed (including an abort after a failed mock pass),
        2 if `--only` names an unknown family or selects no family at all,
        3 if the service's /health endpoint cannot be reached.
    """
    args = parse_args(argv)
    families = [f.strip() for f in args.only.split(",") if f.strip()]
    unknown = [f for f in families if f not in FAMILIES]
    if unknown:
        print(f"unknown family/families: {unknown}; valid: {list(FAMILIES)}", file=sys.stderr)
        return 2
    if not families:
        # An empty selection would exercise nothing and still exit 0 — a vacuous "pass".
        print(f"no families selected via --only; valid: {list(FAMILIES)}", file=sys.stderr)
        return 2

    print(f"Loading fixtures from {args.data_root}/ …")
    fx_full = load_fixtures(args.data_root)
    print(
        f"  papers={len(fx_full.papers)}  criteria={len(fx_full.criteria)}  "
        f"questions={len(fx_full.questions)}  indexer_records={len(fx_full.indexer_records)}  "
        f"indexer_fields={len(fx_full.indexer_fields)}  country_values={len(fx_full.country_values)}"
    )

    output_dir: Path | None = None
    if not args.no_output:
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        output_dir = Path(args.output_root) / ts
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"  dumping responses to {output_dir}/")
        _dump_inputs(output_dir, fx_full, args)

    client = ApiClient(args.base_url, args.api_key, args.openai_key, output_dir=output_dir)
    # try/finally guarantees the HTTP client is closed on every exit path, including
    # unexpected exceptions raised while sampling fixtures or running a family.
    try:
        # Verify the service is up before any AI calls
        try:
            health = client.get("/health")
            print(f"  health: {health}")
        except Exception as e:
            print(f"FATAL: cannot reach {args.base_url}/health — {e}", file=sys.stderr)
            return 3

        overall_fail = 0

        if not args.live_only:
            mock_report = RunReport()
            print("\n══ MOCK PASS ══")
            for fam in families:
                FAMILIES[fam](client, fx_full, mock=True, report=mock_report)
            _print_summary("MOCK", mock_report)
            overall_fail += len(mock_report.fail)
            if mock_report.fail and not args.continue_on_fail and not args.mock_only:
                print("\nAborting before LIVE pass (use --continue-on-fail to override).")
                return 1

        if not args.mock_only:
            if not args.openai_key:
                print(
                    "\nNote: no OPENAI_API_KEY on host shell — relying on the backend's CRYSTALLISE_OPENAI_API_KEY"
                    " from .env (loaded via docker-compose env_file). Pass --openai-key only to override per request.",
                    file=sys.stderr,
                )
            live_report = RunReport()
            print(f"\n══ LIVE PASS (sample={args.live_sample}) ══")
            fx_live = sample_fixtures(fx_full, args.live_sample)
            for fam in families:
                FAMILIES[fam](client, fx_live, mock=False, report=live_report)
            _print_summary("LIVE", live_report)
            overall_fail += len(live_report.fail)

        return 0 if overall_fail == 0 else 1
    finally:
        client.close()


# Script entry point: run the smoke test and use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
