"""Tests for crystallise.llm.errors — unified error taxonomy."""
from unittest.mock import MagicMock

import openai

from crystallise.llm.errors import (
    ErrorCategory,
    LLMError,
    RETRYABLE_CATEGORIES,
    classify_openai_error,
)


class TestErrorCategory:
    """Pin the values of ``ErrorCategory`` and the contents of ``RETRYABLE_CATEGORIES``."""

    def test_all_categories_exist(self):
        """Each category member compares equal to its expected string."""
        expected_values = (
            (ErrorCategory.TRANSIENT, "transient"),
            (ErrorCategory.RATE_LIMIT, "rate_limit"),
            (ErrorCategory.TIMEOUT, "timeout"),
            (ErrorCategory.AUTH, "auth"),
            (ErrorCategory.VALIDATION, "validation"),
            (ErrorCategory.UNKNOWN, "unknown"),
        )
        for member, value in expected_values:
            assert member == value

    def test_retryable_categories(self):
        """Transient, rate-limit and timeout are retryable; the rest are not."""
        should_retry = (
            ErrorCategory.TRANSIENT,
            ErrorCategory.RATE_LIMIT,
            ErrorCategory.TIMEOUT,
        )
        should_not_retry = (
            ErrorCategory.AUTH,
            ErrorCategory.VALIDATION,
            ErrorCategory.UNKNOWN,
        )
        for category in should_retry:
            assert category in RETRYABLE_CATEGORIES
        for category in should_not_retry:
            assert category not in RETRYABLE_CATEGORIES


class TestLLMError:
    """Exercise construction, attributes and representation of ``LLMError``."""

    def test_basic_creation(self):
        """Message and category are stored; ``original`` defaults to ``None``."""
        message = "test error"
        error = LLMError(message, ErrorCategory.TRANSIENT)
        assert str(error) == message
        assert error.category == ErrorCategory.TRANSIENT
        assert error.original is None

    def test_with_original_exception(self):
        """The wrapped exception is kept by identity on ``original``."""
        cause = ValueError("original cause")
        error = LLMError("wrapped", ErrorCategory.VALIDATION, original=cause)
        assert error.original is cause
        assert error.category == ErrorCategory.VALIDATION

    def test_retryable_property(self):
        """``retryable`` is exactly ``True`` or ``False`` depending on the category."""
        expectations = (
            (ErrorCategory.TRANSIENT, True),
            (ErrorCategory.RATE_LIMIT, True),
            (ErrorCategory.TIMEOUT, True),
            (ErrorCategory.AUTH, False),
            (ErrorCategory.VALIDATION, False),
            (ErrorCategory.UNKNOWN, False),
        )
        for category, expected in expectations:
            assert LLMError("x", category).retryable is expected

    def test_repr(self):
        """``repr`` mentions both the category value and the message."""
        text = repr(LLMError("test", ErrorCategory.AUTH))
        for fragment in ("auth", "test"):
            assert fragment in text

    def test_is_exception(self):
        """``LLMError`` instances are ``Exception`` instances."""
        assert isinstance(LLMError("test", ErrorCategory.TRANSIENT), Exception)


class TestClassifyOpenaiError:
    """Check the category ``classify_openai_error`` assigns to each exception.

    Covers OpenAI SDK exception classes, plain exceptions carrying a
    ``status_code`` / ``status`` attribute, and exceptions with neither.
    """

    @staticmethod
    def _status_error(error_cls, message, status_code):
        """Instantiate an SDK status error around a mocked response."""
        return error_cls(
            message=message,
            response=MagicMock(status_code=status_code),
            body=None,
        )

    @staticmethod
    def _plain_exception(message, **attrs):
        """Build a bare ``Exception`` and set the given attributes on it."""
        exc = Exception(message)
        for attr_name, attr_value in attrs.items():
            setattr(exc, attr_name, attr_value)
        return exc

    # --- SDK exception classes -------------------------------------------

    def test_rate_limit_error(self):
        error = self._status_error(openai.RateLimitError, "rate limited", 429)
        assert classify_openai_error(error) == ErrorCategory.RATE_LIMIT

    def test_timeout_error(self):
        error = openai.APITimeoutError(request=MagicMock())
        assert classify_openai_error(error) == ErrorCategory.TIMEOUT

    def test_connection_error(self):
        error = openai.APIConnectionError(request=MagicMock())
        assert classify_openai_error(error) == ErrorCategory.TRANSIENT

    def test_internal_server_error(self):
        error = self._status_error(openai.InternalServerError, "server error", 500)
        assert classify_openai_error(error) == ErrorCategory.TRANSIENT

    def test_authentication_error(self):
        error = self._status_error(openai.AuthenticationError, "invalid key", 401)
        assert classify_openai_error(error) == ErrorCategory.AUTH

    def test_bad_request_error(self):
        error = self._status_error(openai.BadRequestError, "bad request", 400)
        assert classify_openai_error(error) == ErrorCategory.VALIDATION

    def test_permission_denied_error(self):
        error = self._status_error(openai.PermissionDeniedError, "forbidden", 403)
        assert classify_openai_error(error) == ErrorCategory.AUTH

    def test_not_found_error(self):
        error = self._status_error(openai.NotFoundError, "not found", 404)
        assert classify_openai_error(error) == ErrorCategory.VALIDATION

    def test_unknown_exception(self):
        error = RuntimeError("something unexpected")
        assert classify_openai_error(error) == ErrorCategory.UNKNOWN

    # --- Plain exceptions carrying a status attribute --------------------

    def test_fallback_status_code_429(self):
        error = self._plain_exception("rate limited", status_code=429)
        assert classify_openai_error(error) == ErrorCategory.RATE_LIMIT

    def test_fallback_status_code_500(self):
        error = self._plain_exception("server error", status_code=500)
        assert classify_openai_error(error) == ErrorCategory.TRANSIENT

    def test_fallback_status_code_502(self):
        error = self._plain_exception("bad gateway", status_code=502)
        assert classify_openai_error(error) == ErrorCategory.TRANSIENT

    def test_fallback_status_code_401(self):
        error = self._plain_exception("unauthorized", status_code=401)
        assert classify_openai_error(error) == ErrorCategory.AUTH

    def test_fallback_status_code_400(self):
        error = self._plain_exception("bad request", status_code=400)
        assert classify_openai_error(error) == ErrorCategory.VALIDATION

    def test_fallback_status_attribute(self):
        """Some SDK versions expose ``.status`` rather than ``.status_code``."""
        error = self._plain_exception("error", status=503)
        assert classify_openai_error(error) == ErrorCategory.TRANSIENT

    def test_no_status_code_returns_unknown(self):
        error = self._plain_exception("plain error")
        assert classify_openai_error(error) == ErrorCategory.UNKNOWN
