"""Tests for the centralised prompt registry."""

from crystallise.prompts.registry import PROMPT_REGISTRY, PromptInfo, list_prompts


class TestPromptRegistry:
    """Structural invariants of ``PROMPT_REGISTRY`` (registry key -> ``PromptInfo``)."""

    # Accepted values for ``PromptInfo.system_or_user``.
    _ROLES = ("system", "user", "both")
    # Every registered prompt must belong to one of these services.
    _SERVICES = frozenset({"screening", "criteria", "indexer"})

    @staticmethod
    def _keys_with_prefix(prefix):
        """Return the registry keys that start with *prefix* (e.g. ``"screening."``)."""
        return [name for name in PROMPT_REGISTRY if name.startswith(prefix)]

    def test_registry_is_not_empty(self):
        # Deliberately stricter than "non-empty": the suite expects at least 20 prompts.
        assert 20 <= len(PROMPT_REGISTRY)

    def test_all_entries_are_prompt_info(self):
        for key in PROMPT_REGISTRY:
            info = PROMPT_REGISTRY[key]
            assert isinstance(info, PromptInfo), f"{key} is not PromptInfo"

    def test_all_entries_have_required_fields(self):
        for key, info in PROMPT_REGISTRY.items():
            # Text fields are checked in a fixed order: name, service, description.
            for field in ("name", "service", "description"):
                assert getattr(info, field), f"{key} missing {field}"
            assert info.system_or_user in self._ROLES, (
                f"{key} has invalid system_or_user: {info.system_or_user}"
            )

    def test_name_matches_dict_key(self):
        for key in PROMPT_REGISTRY:
            info = PROMPT_REGISTRY[key]
            assert key == info.name, f"Key {key} != info.name {info.name}"

    def test_services_are_valid(self):
        for key, info in PROMPT_REGISTRY.items():
            assert info.service in self._SERVICES, f"{key} has unknown service: {info.service}"

    def test_screening_prompts_present(self):
        # Expected: labelling, reasoning, clustering, cluster_selection.
        assert len(self._keys_with_prefix("screening.")) >= 4

    def test_criteria_prompts_present(self):
        # Expected: exclusion, inclusion, pico, refinement, consolidation.
        assert len(self._keys_with_prefix("criteria.")) >= 5

    def test_question_analysis_prompt_registered(self):
        key = "criteria.question_analysis"
        assert key in PROMPT_REGISTRY
        info = PROMPT_REGISTRY[key]
        assert info.service == "criteria"
        assert info.system_or_user in self._ROLES

    def test_indexer_prompts_present(self):
        # Expected: pipeline, refinement, grouping.
        assert len(self._keys_with_prefix("indexer.")) >= 3


class TestListPrompts:
    """Tests for ``list_prompts``, the summary of the registry exposed to the API."""

    # Keys every summary dict returned by ``list_prompts`` must carry.
    REQUIRED_KEYS = frozenset(
        {"name", "service", "description", "has_variables", "system_or_user"}
    )

    def test_returns_list_of_dicts(self):
        """list_prompts returns a list with one dict per registry entry."""
        result = list_prompts()
        assert isinstance(result, list)
        assert len(result) == len(PROMPT_REGISTRY)
        # Check the element type too, as the test name promises.
        non_dicts = [item for item in result if not isinstance(item, dict)]
        assert not non_dicts, f"non-dict entries returned: {non_dicts!r}"

    def test_dict_structure(self):
        """Every summary dict exposes all of REQUIRED_KEYS."""
        result = list_prompts()
        for item in result:
            assert isinstance(item, dict)
            # Report every missing key at once rather than stopping at the first.
            missing = self.REQUIRED_KEYS.difference(item)
            assert not missing, f"{item.get('name', item)!r} missing keys: {sorted(missing)}"

    def test_serialisable(self):
        """list_prompts output must be JSON-serialisable for the API."""
        import json

        result = list_prompts()
        serialised = json.dumps(result)
        assert isinstance(serialised, str)
        roundtrip = json.loads(serialised)
        assert len(roundtrip) == len(result)


class TestPromptImports:
    """Verify that prompts can be imported from the centralised location."""

    def test_screening_imports(self):
        """The centralised screening module exposes the labelling message factory."""
        from crystallise.prompts.screening import make_labelling_system_message as factory

        assert callable(factory)

    def test_screening_backwards_compat(self):
        """The legacy ``screening.sysmsg`` path still exposes the factory."""
        from crystallise.screening.sysmsg import make_labelling_system_message as legacy_factory

        assert callable(legacy_factory)

    def test_criteria_imports(self):
        """The centralised criteria module exposes a prompt builder and the schema."""
        from crystallise.prompts.criteria import (
            exclusion_generation_system_prompt as builder,
            CRITERIA_SCHEMA as schema,
        )

        assert callable(builder)
        assert isinstance(schema, dict)

    def test_criteria_backwards_compat(self):
        """The legacy ``criteria.prompts`` path still exposes the schema."""
        from crystallise.criteria.prompts import CRITERIA_SCHEMA as legacy_schema

        assert isinstance(legacy_schema, dict)

    def test_indexer_imports(self):
        """The centralised indexer module exposes both pipeline prompt strings."""
        from crystallise.prompts.indexer import (
            PIPELINE_SYSTEM_PROMPT,
            PIPELINE_USER_PROMPT,
        )

        for prompt in (PIPELINE_SYSTEM_PROMPT, PIPELINE_USER_PROMPT):
            assert isinstance(prompt, str)

    def test_indexer_backwards_compat(self):
        """The legacy ``indexer.pipeline`` defaults are still importable strings."""
        from crystallise.indexer.pipeline import (
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_USER_PROMPT,
        )

        for prompt in (DEFAULT_SYSTEM_PROMPT, DEFAULT_USER_PROMPT):
            assert isinstance(prompt, str)
