"""Tests for crystallise.indexer.schema_builder and the extraction pipeline."""

import pytest  # noqa: F401
import pandas as pd

from crystallise.indexer.pipeline import validate_extractions
from crystallise.indexer.schema_builder import (
    create_function_schema,
    validate_input_dataframe,
    validate_schema_fields,
)
from crystallise.prompts.indexer import build_field_injection_block


class TestCreateFunctionSchema:
    """Exercise create_function_schema: each field becomes a value/confidence/evidence wrapper."""

    @staticmethod
    def _props(schema):
        """Return the top-level ``properties`` mapping of a generated function schema."""
        return schema["parameters"]["properties"]

    def test_generates_typed_wrappers(self, sample_indexer_fields):
        generated = create_function_schema("extract_v2", "Extract", sample_indexer_fields)
        assert generated["name"] == "extract_v2"
        properties = self._props(generated)
        assert "study_type" in properties
        # Every field is emitted as an object wrapper around its extracted value.
        study_type = properties["study_type"]
        assert study_type["type"] == "object"
        inner = study_type["properties"]
        for key in ("value", "confidence", "evidence"):
            assert key in inner
        assert study_type["required"] == ["value", "confidence", "evidence"]

    def test_preserves_field_types_on_value(self, sample_indexer_fields):
        properties = self._props(create_function_schema("fn", "desc", sample_indexer_fields))
        # Scalar string field: anyOf = [string, null].
        study_options = properties["study_type"]["properties"]["value"]["anyOf"]
        assert study_options[0]["type"] == "string"
        assert study_options[1]["type"] == "null"
        # Array field: anyOf = [array of strings, null].
        outcome_options = properties["outcomes"]["properties"]["value"]["anyOf"]
        array_branch = outcome_options[0]
        assert array_branch["type"] == "array"
        assert array_branch["items"]["type"] == "string"
        assert outcome_options[1]["type"] == "null"

    def test_allows_null_on_value(self):
        number_field = {"name": "x", "description": "test", "data_type_primary": "number"}
        properties = self._props(create_function_schema("fn", "desc", [number_field]))
        assert properties["x"]["properties"]["value"]["anyOf"][1]["type"] == "null"

    def test_full_depth_includes_reasoning(self):
        full_field = {"name": "x", "description": "test", "data_type_primary": "string", "depth": "full"}
        inner = self._props(create_function_schema("fn", "desc", [full_field]))["x"]["properties"]
        assert "reasoning" in inner
        assert "normalised_value" in inner

    def test_minimal_depth_excludes_reasoning(self):
        minimal_field = {"name": "x", "description": "test", "data_type_primary": "string", "depth": "minimal"}
        inner = self._props(create_function_schema("fn", "desc", [minimal_field]))["x"]["properties"]
        assert "reasoning" not in inner
        assert "normalised_value" not in inner

    def test_enum_mode_on_v2_value(self):
        enum_field = {
            "name": "design",
            "description": "Study design",
            "data_type_primary": "string",
            "examples": ["RCT", "cohort"],
            "examples_mode": "enum",
        }
        design = self._props(create_function_schema("fn", "desc", [enum_field]))["design"]
        # The enum constraint sits on the typed (non-null) anyOf branch.
        assert design["properties"]["value"]["anyOf"][0]["enum"] == ["RCT", "cohort"]

    def test_compound_array_type_normalized(self):
        """Compound types like 'array-string' are normalized in the indexer schema."""
        tags_field = {"name": "tags", "description": "Tags", "data_type_primary": "array-string"}
        tags = self._props(create_function_schema("fn", "desc", [tags_field]))["tags"]
        options = tags["properties"]["value"]["anyOf"]
        # 'array-string' becomes an array of strings, followed by the null branch.
        assert options[0]["type"] == "array"
        assert options[0]["items"]["type"] == "string"
        assert options[1]["type"] == "null"

    def test_empty_fields(self):
        params = create_function_schema("fn", "desc", [])["parameters"]
        assert params["properties"] == {}
        assert params["required"] == []

    def test_evidence_schema_structure(self, sample_indexer_fields):
        study_type = self._props(create_function_schema("fn", "desc", sample_indexer_fields))["study_type"]
        evidence = study_type["properties"]["evidence"]
        assert evidence["type"] == "array"
        item_properties = evidence["items"]["properties"]
        assert "text" in item_properties
        assert "section" in item_properties
        assert item_properties["section"]["enum"] == ["title", "abstract"]


class TestValidateV2Extractions:
    """Check the confidence post-processing applied by validate_extractions."""

    def test_null_value_forces_zero_confidence(self):
        # A null value keeps neither its confidence nor its evidence.
        payload = {"study_type": {"value": None, "confidence": 0.8, "evidence": [{"text": "x"}]}}
        checked = validate_extractions(payload)["study_type"]
        assert checked["confidence"] == 0.0
        assert checked["evidence"] == []

    def test_no_evidence_caps_confidence(self):
        # With an empty evidence list, a 0.9 confidence is reduced to 0.5.
        payload = {"sample_size": {"value": 100, "confidence": 0.9, "evidence": []}}
        assert validate_extractions(payload)["sample_size"]["confidence"] == 0.5

    def test_valid_extraction_unchanged(self):
        payload = {
            "study_type": {
                "value": "RCT",
                "confidence": 0.95,
                "evidence": [{"text": "randomized controlled trial"}],
            }
        }
        assert validate_extractions(payload)["study_type"]["confidence"] == 0.95

    def test_low_confidence_with_evidence_unchanged(self):
        payload = {"x": {"value": "unclear", "confidence": 0.3, "evidence": [{"text": "some text"}]}}
        assert validate_extractions(payload)["x"]["confidence"] == 0.3


class TestBuildFieldInjectionBlock:
    """Check the per-field prompt text emitted by build_field_injection_block."""

    @staticmethod
    def _render(**field):
        """Render the injection block for a single field spec given as keyword arguments."""
        return build_field_injection_block([field])

    def test_basic_field_injection(self):
        rendered = build_field_injection_block(
            [
                {"name": "study_design", "description": "Type of study", "data_type_primary": "string"},
                {"name": "sample_size", "description": "Participant count", "data_type_primary": "number"},
            ]
        )
        for snippet in ("Fields to extract:", "- study_design: Type of study", "Extract as a number."):
            assert snippet in rendered

    def test_array_type_instruction(self):
        rendered = self._render(name="outcomes", description="Outcomes", data_type_primary="array")
        assert "Return as array." in rendered

    def test_compound_array_type_instruction(self):
        """Compound types like 'array-string' should also get array instruction."""
        rendered = self._render(name="tags", description="Tags", data_type_primary="array-string")
        assert "Return as array." in rendered

    def test_boolean_type_instruction(self):
        rendered = self._render(name="is_rct", description="Is RCT?", data_type_primary="boolean")
        assert "Return true/false." in rendered

    def test_guide_examples_in_block(self):
        rendered = self._render(
            name="design",
            description="Design",
            data_type_primary="string",
            examples=["RCT", "cohort"],
            examples_mode="guide",
        )
        assert "Examples: RCT, cohort." in rendered

    def test_enum_examples_in_block(self):
        rendered = self._render(
            name="design",
            description="Design",
            data_type_primary="string",
            examples=["RCT", "cohort"],
            examples_mode="enum",
        )
        assert "Must be one of: RCT, cohort." in rendered

    def test_empty_fields_returns_empty(self):
        rendered = build_field_injection_block([])
        assert rendered == ""


class TestValidateInputDataframe:
    """Check column validation of the input records dataframe (ID / Title / Abstract)."""

    def test_valid_dataframe_returns_true(self):
        records = pd.DataFrame(
            [
                {"ID": "1", "Title": "Study A", "Abstract": "Abstract A"},
                {"ID": "2", "Title": "Study B", "Abstract": "Abstract B"},
            ]
        )
        ok, problems = validate_input_dataframe(records)
        assert ok is True
        assert problems == []

    def test_missing_column_returns_errors(self):
        # "Abstract" is deliberately omitted; the error list should mention it.
        records = pd.DataFrame([{"ID": "1", "Title": "A"}])
        ok, problems = validate_input_dataframe(records)
        assert ok is False
        assert any("Abstract" in message for message in problems)


class TestValidateSchemaFields:
    """Check column validation of the schema-fields dataframe."""

    def test_valid_dataframe_returns_true(self):
        fields_df = pd.DataFrame(
            [
                {
                    "name": "study_type",
                    "description": "Type of study",
                    "data_type_primary": "string",
                    "data_type_secondary": "NA",
                }
            ]
        )
        ok, problems = validate_schema_fields(fields_df)
        assert ok is True
        assert problems == []

    def test_missing_column_returns_errors(self):
        # "data_type_primary" is deliberately omitted; the error list should mention it.
        fields_df = pd.DataFrame([{"name": "study_type", "description": "desc"}])
        ok, problems = validate_schema_fields(fields_df)
        assert ok is False
        assert any("data_type_primary" in message for message in problems)


# ── Schema Strictness Tests ──


class TestSchemaStrictness:
    """Check the strict-mode flags (strict, additionalProperties, required) on generated schemas."""

    @staticmethod
    def _study_type_evidence(schema):
        """Return the evidence sub-schema of the ``study_type`` field wrapper."""
        return schema["parameters"]["properties"]["study_type"]["properties"]["evidence"]

    def test_has_strict_true(self, sample_indexer_fields):
        generated = create_function_schema("fn", "desc", sample_indexer_fields)
        assert generated["strict"] is True

    def test_has_additional_properties_false_on_parameters(self, sample_indexer_fields):
        params = create_function_schema("fn", "desc", sample_indexer_fields)["parameters"]
        assert params["additionalProperties"] is False

    def test_has_additional_properties_false_on_field_wrappers(self, sample_indexer_fields):
        properties = create_function_schema("fn", "desc", sample_indexer_fields)["parameters"]["properties"]
        for field_name in ("study_type", "population", "outcomes"):
            assert (
                properties[field_name]["additionalProperties"] is False
            ), f"{field_name} wrapper missing additionalProperties"

    def test_evidence_items_have_additional_properties_false(self, sample_indexer_fields):
        evidence = self._study_type_evidence(create_function_schema("fn", "desc", sample_indexer_fields))
        assert evidence["items"]["additionalProperties"] is False

    def test_evidence_section_is_required(self, sample_indexer_fields):
        evidence = self._study_type_evidence(create_function_schema("fn", "desc", sample_indexer_fields))
        required = evidence["items"]["required"]
        assert "section" in required
        assert "text" in required

    def test_full_depth_requires_all_properties(self):
        full_field = {"name": "x", "description": "d", "data_type_primary": "string", "depth": "full"}
        wrapper = create_function_schema("fn", "desc", [full_field])["parameters"]["properties"]["x"]
        # Strict mode needs every declared property listed as required.
        for key in ("reasoning", "normalised_value", "value", "confidence", "evidence"):
            assert key in wrapper["required"]

    def test_minimal_depth_does_not_require_reasoning(self):
        minimal_field = {"name": "x", "description": "d", "data_type_primary": "string", "depth": "minimal"}
        wrapper = create_function_schema("fn", "desc", [minimal_field])["parameters"]["properties"]["x"]
        required = wrapper["required"]
        assert "reasoning" not in required
        assert "normalised_value" not in required


# ── Enhanced Field Injection Tests ──


class TestFieldInjectionEnhancements:
    """Check the explicit-extraction guidance text that field injection adds per field."""

    def test_all_fields_include_explicit_extraction_rule(self):
        plain = {"name": "x", "description": "A field", "data_type_primary": "string"}
        rendered = build_field_injection_block([plain])
        assert "Only extract if explicitly stated in the text." in rendered

    def test_enum_mode_includes_prefer_explicit(self):
        enum_field = {
            "name": "design",
            "description": "Study design",
            "data_type_primary": "string",
            "examples": ["RCT", "cohort"],
            "examples_mode": "enum",
        }
        rendered = build_field_injection_block([enum_field])
        assert "Prefer explicit mention." in rendered
        assert "Must be one of: RCT, cohort." in rendered

    def test_guide_mode_does_not_include_prefer_explicit(self):
        guide_field = {
            "name": "pop",
            "description": "Population",
            "data_type_primary": "string",
            "examples": ["adults", "children"],
            "examples_mode": "guide",
        }
        rendered = build_field_injection_block([guide_field])
        # Guide-mode examples are shown without the "prefer explicit" instruction.
        assert "Prefer explicit mention." not in rendered
        assert "Examples: adults, children." in rendered


# ── Field Suggestion Tests ──


class TestFieldSuggestion:
    """Test the async suggest_fields function with the LLM completion call mocked out."""

    # NOTE(review): this patches the attribute on crystallise.llm.client. That only reliably
    # intercepts the call if field_suggestion looks the function up through that module at
    # call time. A module-level ``from crystallise.llm.client import async_chat_completion``
    # in field_suggestion would bind whichever object existed at first import. Confirm against
    # field_suggestion's import style.
    _LLM_TARGET = "crystallise.llm.client.async_chat_completion"

    @classmethod
    def _patch_llm(cls, return_value):
        """Return a patcher that swaps the LLM completion call for an AsyncMock.

        Awaiting the mock yields ``return_value``. Use as a context manager.
        """
        from unittest.mock import AsyncMock, patch

        return patch(cls._LLM_TARGET, new_callable=AsyncMock, return_value=return_value)

    @pytest.mark.asyncio
    async def test_returns_fields_and_warnings_tuple(self):
        mock_response = '{"fields": [{"name": "study_design", "description": "Type", "data_type_primary": "string", "examples": ["RCT"], "extraction_difficulty": "low"}], "warnings": [{"field": "biomarker", "risk_level": "high", "reason": "Rarely in abstracts"}]}'

        with self._patch_llm(mock_response):
            from crystallise.indexer.field_suggestion import suggest_fields

            fields, warnings = await suggest_fields(project_description="test")
            assert len(fields) == 1
            assert fields[0]["name"] == "study_design"
            assert fields[0]["extraction_difficulty"] == "low"
            assert len(warnings) == 1
            assert warnings[0]["field"] == "biomarker"

    @pytest.mark.asyncio
    async def test_returns_empty_on_no_content(self):
        with self._patch_llm(None):
            from crystallise.indexer.field_suggestion import suggest_fields

            fields, warnings = await suggest_fields()
            assert fields == []
            assert warnings == []

    @pytest.mark.asyncio
    async def test_returns_empty_on_malformed_json(self):
        with self._patch_llm("not json at all"):
            from crystallise.indexer.field_suggestion import suggest_fields

            fields, warnings = await suggest_fields()
            assert fields == []
            assert warnings == []

    @pytest.mark.asyncio
    async def test_handles_list_response_without_fields_key(self):
        # A bare JSON list (no "fields" wrapper) is still accepted as the field list.
        mock_response = '[{"name": "x", "description": "y", "data_type_primary": "string"}]'

        with self._patch_llm(mock_response):
            from crystallise.indexer.field_suggestion import suggest_fields

            fields, warnings = await suggest_fields()
            assert len(fields) == 1
            assert warnings == []


# ── Refinement Merge Action Tests ──


class TestRefinementMergeAction:
    """Test the indexer API schemas used by refinement.

    Covers FieldSuggestion actions (including "merge"), ExtractionWarning and
    SuggestFieldsResponse.
    """

    def test_merge_action_accepted_by_schema(self):
        from api.schemas.indexer import FieldSuggestion, IndexerField

        merged_field = IndexerField(name="study_design", description="Merged field")
        suggestion = FieldSuggestion(
            action="merge",
            field=merged_field,
            rationale="study_type and study_design are the same concept",
            target_field_name="study_design",
        )
        assert suggestion.action == "merge"
        assert suggestion.target_field_name == "study_design"

    def test_all_actions_accepted_by_schema(self):
        from api.schemas.indexer import FieldSuggestion, IndexerField

        for action_name in ("add", "modify", "remove", "merge"):
            suggestion = FieldSuggestion(action=action_name, field=IndexerField(name="f"), rationale="test")
            assert suggestion.action == action_name

    def test_extraction_warning_schema(self):
        from api.schemas.indexer import ExtractionWarning

        warning = ExtractionWarning(
            field="biomarker",
            risk_level="high",
            reason="Rarely stated in abstracts",
            suggested_fallback="Use broader category",
        )
        assert warning.field == "biomarker"
        assert warning.risk_level == "high"
        assert warning.suggested_fallback == "Use broader category"

    def test_suggest_fields_response_with_warnings(self):
        from api.schemas.indexer import ExtractionWarning, IndexerField, SuggestFieldsResponse

        single_field = IndexerField(name="x")
        single_warning = ExtractionWarning(field="x", reason="test")
        response = SuggestFieldsResponse(fields=[single_field], warnings=[single_warning])
        assert len(response.fields) == 1
        assert len(response.warnings) == 1


# ── Examples Flattening Tests (API Router) ──


class TestExamplesFlattening:
    """Test that nested arrays in examples are flattened correctly."""

    # NOTE(review): these tests exercise a local copy of the flattening logic (see the
    # "Simulate what the router does" docstring), not the router itself, so they cannot catch
    # regressions in the router. Point them at the router helper once one is factored out.

    @staticmethod
    def _flatten(raw_examples: list) -> list[str]:
        """Flatten one level of nesting, drop falsy entries and stringify the rest.

        List elements are expanded in place. Any falsy value (e.g. ``""`` or ``None``)
        at either the outer or the inner level is skipped.
        """
        flat: list[str] = []
        for entry in raw_examples:
            if isinstance(entry, list):
                flat.extend(str(v) for v in entry if v)
            elif entry:
                flat.append(str(entry))
        return flat

    def test_flatten_nested_example_arrays(self):
        """Simulate what the router does with nested arrays from LLM."""
        raw_examples: list = [["colorectal cancer", "breast cancer"], "melanoma", ["NSCLC"]]
        assert self._flatten(raw_examples) == ["colorectal cancer", "breast cancer", "melanoma", "NSCLC"]

    def test_flatten_already_flat_examples(self):
        raw_examples: list = ["RCT", "cohort", "case-control"]
        assert self._flatten(raw_examples) == ["RCT", "cohort", "case-control"]

    def test_flatten_skips_empty_values(self):
        raw_examples: list = ["", None, "valid", ["", "also_valid"]]
        assert self._flatten(raw_examples) == ["valid", "also_valid"]

    def test_flatten_stringifies_non_string_values(self):
        # Non-strings are converted with str(); falsy ones such as 0 are dropped.
        raw_examples: list = [1, [2, 0], 0]
        assert self._flatten(raw_examples) == ["1", "2"]
