"""Tests for screening pipeline utilities (no LLM calls)."""

from crystallise.screening.pipeline import (
    ScreeningError,
    convert_criteria_format,
)


class TestConvertCriteriaFormat:
    """Exercise ``convert_criteria_format`` on flat lists of criterion dicts."""

    def test_converts_flat_list_to_nested_dict(self):
        """Each value is expected under ``result[name][type]``."""
        rows = (
            ("Population", "include", "Adults 18+"),
            ("Population", "exclude", "Children"),
            ("Study Design", "exclude", "Case reports"),
        )
        flat = [
            {"name": name, "type": kind, "value": value}
            for name, kind, value in rows
        ]

        nested = convert_criteria_format(flat)

        # Every input row must be reachable through its name and type.
        for name, kind, value in rows:
            assert name in nested
            assert value in nested[name][kind]

    def test_empty_list_returns_empty_dict(self):
        """An empty criteria list is expected to yield an empty dict."""
        assert convert_criteria_format([]) == {}

    def test_skips_empty_values(self):
        """A criterion whose value is ``""`` must not appear in the output."""
        flat = [
            {"name": "Population", "type": "include", "value": text}
            for text in ("Adults", "")
        ]

        nested = convert_criteria_format(flat)

        included = nested["Population"]["include"]
        assert included == ["Adults"]

    def test_defaults_to_exclude_type(self):
        """A criterion with no ``type`` key is expected under ``exclude``."""
        untyped = {"name": "Other", "value": "Something"}

        nested = convert_criteria_format([untyped])

        assert "Something" in nested["Other"]["exclude"]

    def test_defaults_to_other_name(self):
        """A criterion with no ``name`` key is expected under ``"Other"``."""
        unnamed = {"type": "include", "value": "Generic criterion"}

        nested = convert_criteria_format([unnamed])

        assert "Other" in nested
        assert "Generic criterion" in nested["Other"]["include"]


class TestScreeningError:
    """Exercise ``ScreeningError`` construction and ``from_exception``."""

    def test_from_exception_classifies_auth_errors(self):
        """An ``AuthenticationError`` is expected to map to category ``auth``."""
        class AuthenticationError(Exception):
            """Local stand-in exception type used only by this test."""

        wrapped = ScreeningError.from_exception(
            AuthenticationError("Invalid API key 401")
        )

        assert wrapped.category == "auth"
        message = str(wrapped).lower()
        assert "authentication" in message

    def test_from_exception_classifies_rate_limit_errors(self):
        """A ``RateLimitError`` is expected to map to category ``rate_limit``."""
        class RateLimitError(Exception):
            """Local stand-in exception type used only by this test."""

        wrapped = ScreeningError.from_exception(
            RateLimitError("Too many requests 429")
        )

        assert wrapped.category == "rate_limit"
        message = str(wrapped).lower()
        assert "rate limit" in message

    def test_from_exception_classifies_timeout_errors(self):
        """A ``TimeoutError`` is expected to map to category ``network``."""
        # NOTE(review): this local class shadows the builtin TimeoutError
        # inside this method; presumably deliberate so the test does not
        # depend on the builtin's OSError ancestry -- confirm.
        class TimeoutError(Exception):
            """Local stand-in exception type used only by this test."""

        wrapped = ScreeningError.from_exception(
            TimeoutError("Connection timed out")
        )

        assert wrapped.category == "network"

    def test_from_exception_sanitizes_api_keys(self):
        """The raw key must be absent from the message and ``sk-***`` present."""
        source = RuntimeError(
            "Error with key sk-abcdefghijklmnop1234567890 in request"
        )

        rendered = str(ScreeningError.from_exception(source))

        assert "sk-abcdefghijklmnop1234567890" not in rendered
        assert "sk-***" in rendered

    def test_from_exception_unknown_category(self):
        """A plain ``ValueError`` is expected to map to category ``unknown``."""
        wrapped = ScreeningError.from_exception(
            ValueError("Something went wrong")
        )

        assert wrapped.category == "unknown"
        # The original message text must still be present in the result.
        assert "Something went wrong" in str(wrapped)

    def test_screening_error_init(self):
        """Direct construction keeps the message and the given category."""
        direct = ScreeningError("Test message", category="test")

        assert str(direct) == "Test message"
        assert direct.category == "test"
