"""FastAPI endpoint tests for screening, criteria, and config routers."""

import time
from unittest.mock import MagicMock

import pytest
from fastapi.testclient import TestClient

from api.main import app

# One synchronous test client shared by every test in this module; all
# requests below are issued through it against the imported FastAPI app.
client = TestClient(app=app)


def _mock_openai_client():
    return MagicMock()


@pytest.fixture(autouse=True)
def override_openai_dep():
    """Replace the ``get_openai_client`` dependency with a ``MagicMock`` factory.

    Only the override this fixture installs is undone at teardown. Any other
    entries in ``app.dependency_overrides`` are left untouched. If an override
    for ``get_openai_client`` was already registered before this fixture ran,
    it is put back rather than discarded.
    """
    from api.dependencies import get_openai_client

    # Sentinel distinguishes "no previous override" from any real value.
    absent = object()
    previous = app.dependency_overrides.get(get_openai_client, absent)
    app.dependency_overrides[get_openai_client] = _mock_openai_client
    try:
        yield
    finally:
        # Undo just our own change; clearing the whole dict would also drop
        # overrides installed by other fixtures or modules sharing ``app``.
        if previous is absent:
            app.dependency_overrides.pop(get_openai_client, None)
        else:
            app.dependency_overrides[get_openai_client] = previous


@pytest.fixture(autouse=True)
def reset_config_registry():
    """Drop the cached global config registry before and after every test.

    The module-level ``_registry`` is set back to ``None`` on both sides of the
    test. State cached there by one test (presumably including service config
    updated via PUT /config/services/...) therefore does not carry over.
    """
    import crystallise.config.registry as registry_module

    def _forget_registry():
        registry_module._registry = None

    _forget_registry()
    yield
    _forget_registry()


@pytest.fixture(autouse=True)
def reset_screening_router():
    """Put the screening router's module state into a known baseline per test.

    Before the test: the active-job map is emptied, ``_table_created`` is
    cleared and ``_db_available`` is forced to ``True``.

    After the test: the active-job map is emptied again and both flags are
    restored to the values they held before this fixture ran. Without the
    restore, whatever a test left behind would leak into other modules that
    import the same router module.
    """
    import api.routers.screening as scr

    saved_table_created = scr._table_created
    saved_db_available = scr._db_available

    scr._active_jobs.clear()
    scr._table_created = False
    scr._db_available = True
    try:
        yield
    finally:
        scr._active_jobs.clear()
        scr._table_created = saved_table_created
        scr._db_available = saved_db_available


class TestScreeningJobEndpoints:
    """Tests for creating and polling screening jobs via ``/screening/jobs``."""

    def test_post_creates_job(self):
        response = client.post(
            "/screening/jobs",
            json={
                "papers": [
                    {"title": "Study A", "abstract": "Randomized trial of drug X"},
                    {"title": "Study B", "abstract": "Review of existing literature"},
                ],
                "criteria": [
                    {"name": "Population", "type": "include", "value": "Adults"},
                ],
                "questions": ["Is drug X effective?"],
                "mock": True,
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert "job_id" in data
        assert data["status"] == "pending"

    def test_get_job_status(self):
        # Create a job first; fail clearly here rather than on a KeyError below.
        create_resp = client.post(
            "/screening/jobs",
            json={
                "papers": [{"title": "Study", "abstract": "A clinical trial"}],
                "mock": True,
            },
        )
        assert create_resp.status_code == 200
        job_id = create_resp.json()["job_id"]

        # Poll until the mock job reaches a terminal state. The deadline is
        # time-based so a slow CI machine gets the same budget regardless of
        # per-request latency.
        terminal_states = ("completed", "failed")
        deadline = time.monotonic() + 5.0
        while True:
            status_resp = client.get(f"/screening/jobs/{job_id}")
            assert status_resp.status_code == 200
            data = status_resp.json()
            assert data["job_id"] == job_id
            assert "status" in data
            if data["status"] in terminal_states or time.monotonic() >= deadline:
                break
            time.sleep(0.1)

        # Without this, a job stuck in "pending" would let the test pass silently.
        assert data["status"] in terminal_states, (
            f"job {job_id} did not finish; last status was {data['status']!r}"
        )

    def test_get_nonexistent_job_returns_404(self):
        response = client.get("/screening/jobs/nonexistent-id")
        assert response.status_code == 404


class TestCriteriaEndpoints:
    """Tests for ``POST /criteria/generate`` in mock mode."""

    def test_post_generate_with_mock(self):
        payload = {
            "project_description": "A systematic review of cancer treatments in adults",
            "research_questions": ["What is the efficacy of drug X?"],
            "mock": True,
        }
        response = client.post("/criteria/generate", json=payload)
        assert response.status_code == 200

        body = response.json()
        assert "criteria" in body
        generated = body["criteria"]
        assert isinstance(generated, list)
        assert len(generated) > 0

        # Each generated criterion must carry both a category and its text.
        for item in generated:
            for field in ("category", "text"):
                assert field in item


class TestCriteriaPicoEndpoint:
    """Tests for ``POST /criteria/pico`` in mock mode."""

    def test_post_pico_with_mock(self):
        payload = {
            "project_description": "A systematic review of cancer treatments in adults",
            "research_questions": ["What is the efficacy of drug X?"],
            "mock": True,
        }
        response = client.post("/criteria/pico", json=payload)
        assert response.status_code == 200

        body = response.json()
        assert "elements" in body
        elements = body["elements"]
        assert isinstance(elements, dict)
        # All four PICO components must be present in the mapping.
        for pico_key in ("population", "intervention", "comparison", "outcome"):
            assert pico_key in elements

        assert "gap_flags" in body
        assert isinstance(body["gap_flags"], list)


class TestCriteriaRefineContextEndpoint:
    """Tests for ``POST /criteria/refine-context`` in mock mode."""

    def test_post_refine_context_with_mock(self):
        payload = {
            "description": "A systematic review of drug efficacy",
            "research_questions": ["Is drug X effective?"],
            "mock": True,
        }
        response = client.post("/criteria/refine-context", json=payload)
        assert response.status_code == 200

        body = response.json()
        for key in ("refined_description", "refined_research_questions", "explanation"):
            assert key in body
        refined_questions = body["refined_research_questions"]
        assert isinstance(refined_questions, list)
        assert len(refined_questions) > 0

    def test_refine_context_preserves_input(self):
        original_text = "Review of drug X"
        response = client.post(
            "/criteria/refine-context",
            json={"description": original_text, "mock": True},
        )
        assert response.status_code == 200
        # The mock refinement is expected to keep the caller's wording intact.
        assert original_text in response.json()["refined_description"]


class TestCriteriaRefineEndpoint:
    """Tests for ``POST /criteria/refine`` in mock mode."""

    def test_post_refine_with_mock(self):
        payload = {
            "current_criteria": [
                {"category": "Population", "text": "Adults only"},
            ],
            "conflicts": [
                {"paper_title": "Study A", "decision_a": "include", "decision_b": "exclude"},
            ],
            "mock": True,
        }
        response = client.post("/criteria/refine", json=payload)
        assert response.status_code == 200

        body = response.json()
        assert "criteria" in body
        refined = body["criteria"]
        assert isinstance(refined, list)
        assert len(refined) > 0

        expected_fields = ("category", "text", "criterion_type")
        for criterion in refined:
            for field in expected_fields:
                assert field in criterion


class TestCriteriaAnalyzeQuestionEndpoint:
    """Tests for the analyze-question endpoint on both versioned and bare paths."""

    def test_post_analyze_question_with_mock(self):
        question = "What is the effect of exercise on depression in adults?"
        response = client.post(
            "/v1/criteria/analyze-question",
            json={"research_question": question, "mock": True},
        )
        assert response.status_code == 200

        body = response.json()
        assert body["status"] in ("ready", "could_improve")

        missing = body["missing_elements"]
        assert isinstance(missing, list)
        assert all(isinstance(element, str) for element in missing)

        suggestion = body["suggestion"]
        assert isinstance(suggestion, str)
        # A whitespace-only suggestion counts as empty.
        assert suggestion.strip()

    def test_post_analyze_question_works_on_unversioned_alias(self):
        payload = {"research_question": "Anything", "mock": True}
        response = client.post("/criteria/analyze-question", json=payload)
        assert response.status_code == 200


class TestCriteriaConsolidateEndpoint:
    """Tests for ``POST /criteria/consolidate`` in mock mode."""

    def test_post_consolidate_with_mock(self):
        # Two near-duplicate population criteria that consolidation should see.
        near_duplicates = [
            {"category": "Population", "text": "Adults 18+"},
            {"category": "Population", "text": "Adult participants over 18"},
        ]
        payload = {
            "criteria": near_duplicates,
            "project_description": "Review of treatments",
            "mock": True,
        }
        response = client.post("/criteria/consolidate", json=payload)
        assert response.status_code == 200

        body = response.json()
        for key in ("duplicate_groups", "consolidation_proposals", "warnings"):
            assert key in body
        assert isinstance(body["warnings"], list)


class TestConfigEndpoints:
    """Tests for the ``/config`` service-config and prompt-registry routes."""

    def test_get_services_list(self):
        response = client.get("/config/services")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) > 0
        service_ids = [c["service_id"] for c in data]
        assert "screening" in service_ids

    def test_get_screening_config(self):
        response = client.get("/config/services/screening")
        assert response.status_code == 200
        data = response.json()
        assert data["service_id"] == "screening"
        assert "model" in data
        assert "temperature" in data
        assert "max_output_tokens" in data

    def test_put_updates_screening_model(self):
        response = client.put(
            "/config/services/screening",
            json={
                "model": "gpt-5-mini",
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert data["model"] == "gpt-5-mini"

        # Verify it persists
        get_resp = client.get("/config/services/screening")
        assert get_resp.json()["model"] == "gpt-5-mini"

    def test_put_updates_temperature(self):
        response = client.put(
            "/config/services/screening",
            json={
                "temperature": 0.5,
            },
        )
        assert response.status_code == 200
        assert response.json()["temperature"] == 0.5

    def test_get_unknown_service_returns_default(self):
        response = client.get("/config/services/unknown_service")
        assert response.status_code == 200
        data = response.json()
        assert data["service_id"] == "unknown_service"
        assert data["model"] == "gpt-5-nano"

    def test_get_prompt_registry(self):
        response = client.get("/config/prompts")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # A lower bound rather than an exact count, so adding or removing a
        # prompt does not break this test (the registry held 23 entries when
        # this was written).
        assert len(data) >= 20
        # Check every entry's shape, not just the first one.
        required_fields = ("name", "service", "description", "has_variables", "system_or_user")
        for entry in data:
            missing = [field for field in required_fields if field not in entry]
            assert not missing, f"prompt {entry.get('name')!r} is missing fields {missing}"

    def test_prompt_registry_contains_key_prompts(self):
        response = client.get("/config/prompts")
        # Without this, an error body (a JSON object) would surface as an
        # obscure TypeError in the comprehension below.
        assert response.status_code == 200
        names = {p["name"] for p in response.json()}
        assert "screening.labelling" in names
        assert "criteria.exclusion_generation" in names
        assert "indexer.pipeline" in names
