"""Tests for crystallise.llm.retry."""
import asyncio

import pytest
from unittest.mock import MagicMock, patch

import openai

from crystallise.llm.retry import (
    call_with_retries,
    async_retry_with_backoff,
    RetryConfig,
    RETRIABLE_STATUS_CODES,
    RETRYABLE_EXCEPTIONS,
)


class TestCallWithRetries:
    """Sync retry (call_with_retries) — used for Responses API.

    Failures are simulated with plain ``Exception`` objects that carry a
    ``status_code`` attribute. Tests that expect a retry patch
    ``time.sleep`` in the retry module so any backoff sleep is a no-op.
    """

    @staticmethod
    def _status_error(message, status_code):
        """Build an ``Exception`` with the given message and ``status_code``."""
        exc = Exception(message)
        exc.status_code = status_code
        return exc

    def test_succeeds_on_first_try(self):
        """A callable that succeeds is invoked exactly once."""
        fn = MagicMock(return_value="ok")
        result = call_with_retries(fn, max_retries=3)
        assert result == "ok"
        assert fn.call_count == 1

    @patch("crystallise.llm.retry.time.sleep")
    def test_retries_on_retriable_status(self, mock_sleep):
        """Two 429 failures followed by a success yield the success value."""
        exc = self._status_error("rate limited", 429)
        fn = MagicMock(side_effect=[exc, exc, "ok"])
        result = call_with_retries(fn, max_retries=5, base_backoff=1.0, jitter=0.0)
        assert result == "ok"
        assert fn.call_count == 3

    @patch("crystallise.llm.retry.time.sleep")
    def test_raises_after_max_retries_exhausted(self, mock_sleep):
        """A persistent 429 is re-raised once ``max_retries`` calls were made."""
        fn = MagicMock(side_effect=self._status_error("rate limited", 429))
        with pytest.raises(Exception, match="rate limited"):
            call_with_retries(fn, max_retries=3, base_backoff=1.0, jitter=0.0)
        assert fn.call_count == 3

    def test_raises_immediately_on_non_retriable_error(self):
        """A 400 error propagates after a single call."""
        fn = MagicMock(side_effect=self._status_error("bad request", 400))
        with pytest.raises(Exception, match="bad request"):
            call_with_retries(fn, max_retries=5)
        assert fn.call_count == 1

    @patch("crystallise.llm.retry.time.sleep")
    def test_retries_on_500_status(self, mock_sleep):
        """A single 500 failure is retried and the second call succeeds."""
        exc = self._status_error("server error", 500)
        fn = MagicMock(side_effect=[exc, "ok"])
        result = call_with_retries(fn, max_retries=3, base_backoff=1.0, jitter=0.0)
        assert result == "ok"
        assert fn.call_count == 2

    @patch("crystallise.llm.retry.time.sleep")
    def test_retries_on_502_status(self, mock_sleep):
        """A single 502 failure is retried and the second call succeeds."""
        exc = self._status_error("bad gateway", 502)
        fn = MagicMock(side_effect=[exc, "ok"])
        result = call_with_retries(fn, max_retries=3, base_backoff=1.0, jitter=0.0)
        assert result == "ok"
        # Mirror the 500 test: exactly one failed call plus one successful call.
        assert fn.call_count == 2

    def test_raises_immediately_on_401_status(self):
        """A 401 error propagates after a single call."""
        fn = MagicMock(side_effect=self._status_error("unauthorized", 401))
        with pytest.raises(Exception, match="unauthorized"):
            call_with_retries(fn, max_retries=5)
        assert fn.call_count == 1

    def test_retriable_status_codes_constant(self):
        """429 and the 5xx gateway/server codes are retriable; 400/401 are not."""
        for code in (429, 500, 502, 503, 504):
            assert code in RETRIABLE_STATUS_CODES
        for code in (400, 401):
            assert code not in RETRIABLE_STATUS_CODES


class TestAsyncRetryWithBackoff:
    """Async retry (async_retry_with_backoff) — used for Chat Completions.

    Each test drives the coroutine with ``asyncio.run`` and counts how many
    times the wrapped coroutine function was invoked, so that "retried" and
    "raised immediately" are actually distinguished.
    """

    @pytest.fixture
    def config(self):
        """Retry configuration with tiny delays so tests stay fast."""
        return RetryConfig(max_retries=3, delays=[0.01, 0.01, 0.01])

    def test_succeeds_on_first_try(self, config):
        """A coroutine that succeeds returns its value unchanged."""
        async def fn():
            return "ok"

        result = asyncio.run(async_retry_with_backoff(fn, config))
        assert result == "ok"

    def test_retries_on_rate_limit_error(self, config):
        """Two ``RateLimitError`` failures are retried; the third call succeeds."""
        call_count = 0

        async def fn():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise openai.RateLimitError(
                    message="rate limited",
                    response=MagicMock(status_code=429),
                    body=None,
                )
            return "ok"

        result = asyncio.run(async_retry_with_backoff(fn, config))
        assert result == "ok"
        assert call_count == 3

    def test_retries_on_timeout_error(self, config):
        """One ``APITimeoutError`` is retried; the second call succeeds."""
        call_count = 0

        async def fn():
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                raise openai.APITimeoutError(request=MagicMock())
            return "ok"

        result = asyncio.run(async_retry_with_backoff(fn, config))
        assert result == "ok"
        assert call_count == 2

    def test_retries_on_connection_error(self, config):
        """One ``APIConnectionError`` is retried; the second call succeeds."""
        call_count = 0

        async def fn():
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                raise openai.APIConnectionError(request=MagicMock())
            return "ok"

        result = asyncio.run(async_retry_with_backoff(fn, config))
        assert result == "ok"
        assert call_count == 2

    def test_retries_on_internal_server_error(self, config):
        """One ``InternalServerError`` is retried; the second call succeeds."""
        call_count = 0

        async def fn():
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                raise openai.InternalServerError(
                    message="server error",
                    response=MagicMock(status_code=500),
                    body=None,
                )
            return "ok"

        result = asyncio.run(async_retry_with_backoff(fn, config))
        assert result == "ok"
        assert call_count == 2

    def test_raises_immediately_on_auth_error(self, config):
        """``AuthenticationError`` propagates without any retry."""
        call_count = 0

        async def fn():
            nonlocal call_count
            call_count += 1
            raise openai.AuthenticationError(
                message="invalid key",
                response=MagicMock(status_code=401),
                body=None,
            )

        with pytest.raises(openai.AuthenticationError):
            asyncio.run(async_retry_with_backoff(fn, config))
        # Without this check the test would also pass if the error were
        # retried until exhaustion and then re-raised.
        assert call_count == 1

    def test_raises_immediately_on_bad_request(self, config):
        """``BadRequestError`` propagates without any retry."""
        call_count = 0

        async def fn():
            nonlocal call_count
            call_count += 1
            raise openai.BadRequestError(
                message="bad request",
                response=MagicMock(status_code=400),
                body=None,
            )

        with pytest.raises(openai.BadRequestError):
            asyncio.run(async_retry_with_backoff(fn, config))
        # "Immediately" means a single invocation, not retry-then-raise.
        assert call_count == 1

    def test_exhaustion_raises_last_exception(self):
        """A persistent retryable error is re-raised once retries run out."""
        config = RetryConfig(max_retries=2, delays=[0.01, 0.01])
        call_count = 0

        async def fn():
            nonlocal call_count
            call_count += 1
            raise openai.RateLimitError(
                message="rate limited",
                response=MagicMock(status_code=429),
                body=None,
            )

        with pytest.raises(openai.RateLimitError):
            asyncio.run(async_retry_with_backoff(fn, config))
        # Only a lower bound is asserted: whether ``max_retries`` counts the
        # first attempt is not pinned down here, but at least one retry must
        # have happened before the error was surfaced.
        assert call_count >= 2

    def test_retryable_exceptions_tuple(self):
        """The four transient OpenAI error types are listed as retryable."""
        for exc_type in (
            openai.RateLimitError,
            openai.APITimeoutError,
            openai.APIConnectionError,
            openai.InternalServerError,
        ):
            assert exc_type in RETRYABLE_EXCEPTIONS


class TestRetryConfig:
    """Construction of RetryConfig with default and explicit values."""

    def test_default_config(self):
        """A bare ``RetryConfig()`` exposes the documented defaults."""
        config = RetryConfig()
        assert config.max_retries == 3
        assert config.delays == [2.0, 8.0, 30.0]
        assert config.base_backoff == 1.5
        assert config.jitter == 0.25

    def test_custom_config(self):
        """Every explicitly passed value is stored on the instance."""
        config = RetryConfig(
            max_retries=5,
            delays=[1.0, 2.0],
            base_backoff=2.0,
            jitter=0.5,
        )
        assert config.max_retries == 5
        assert config.delays == [1.0, 2.0]
        # These two were passed in but previously never verified.
        assert config.base_backoff == 2.0
        assert config.jitter == 0.5
