"""
Unified retry helpers combining the patterns used in two source repositories:
- AI_Screening_UI: async exponential backoff (2s, 8s, 30s) for chat completions
- CrystalliseAppsToolbox: sync backoff with jitter for Responses API / vector store ops
"""

from __future__ import annotations

import asyncio
import logging
import random
import time
from dataclasses import dataclass, field
from typing import Any, Callable, TypeVar

import openai

from .errors import classify_openai_error, RETRYABLE_CATEGORIES

# Module logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)

# Generic return type used in call_with_retries' signature.
T = TypeVar("T")

# HTTP status codes considered transient and therefore safe to retry.
# NOTE(review): not referenced by the helpers in this module — confirm
# whether external callers rely on it.
RETRIABLE_STATUS_CODES = {
    429,  # Too Many Requests (rate limited)
    500,  # Internal Server Error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
}

# openai exception classes that async_retry_with_backoff treats as transient.
RETRYABLE_EXCEPTIONS = (
    openai.RateLimitError,       # rate limiting (429)
    openai.APITimeoutError,      # request timed out
    openai.APIConnectionError,   # network / connection failure
    openai.InternalServerError,  # server-side 5xx errors
)


@dataclass
class RetryConfig:
    """Bundle of tunable settings for the retry helpers.

    ``async_retry_with_backoff`` reads ``max_retries`` and ``delays``.
    ``base_backoff`` and ``jitter`` carry the same defaults that
    ``call_with_retries`` uses for its keyword arguments.
    """

    # Total number of attempts made by async_retry_with_backoff.
    max_retries: int = 3
    # Seconds to sleep before each retry. When there are more attempts than
    # entries, the last entry is reused. The factory ``list.copy`` hands every
    # instance its own fresh list.
    delays: list[float] = field(default_factory=[2.0, 8.0, 30.0].copy)
    # Exponential base. NOTE(review): not read by the helpers visible in this
    # module; confirm external use.
    base_backoff: float = 1.5
    # Upper bound, in seconds, of the uniform random jitter. NOTE(review): also
    # not read by the helpers visible in this module.
    jitter: float = 0.25


async def async_retry_with_backoff(
    func: Callable[..., Any],
    config: RetryConfig | None = None,
) -> Any:
    """
    Async retry with a fixed backoff schedule for OpenAI API calls.

    Awaits ``func()`` up to ``config.max_retries`` times in total. At least one
    attempt is always made. Between attempts it sleeps for the next entry of
    ``config.delays`` and reuses the last entry once the list runs out. An
    empty list means no sleep. ``config.base_backoff`` and ``config.jitter``
    are not used here.

    Only exceptions in ``RETRYABLE_EXCEPTIONS`` are retried: rate limits,
    timeouts, connection errors and 5xx. Anything else propagates
    immediately, including ``openai.AuthenticationError``.

    Args:
        func: zero-argument callable returning an awaitable.
        config: retry settings; ``RetryConfig()`` when omitted.

    Returns:
        The result of the first successful ``await func()``.

    Raises:
        The retryable exception from the final attempt once attempts are used
        up, with its original traceback. Any non-retryable exception is raised
        as soon as it occurs.
    """
    cfg = config or RetryConfig()
    # Always make at least one attempt. Previously max_retries < 1 never
    # called func at all.
    attempts = max(1, cfg.max_retries)

    for attempt in range(attempts):
        try:
            return await func()
        except RETRYABLE_EXCEPTIONS as e:
            if attempt + 1 >= attempts:
                # Out of attempts: re-raise now instead of sleeping for
                # nothing. A bare raise keeps the original traceback.
                raise
            # Clamp into the schedule; the last delay repeats. An empty
            # schedule means retry immediately instead of raising IndexError.
            delay = cfg.delays[min(attempt, len(cfg.delays) - 1)] if cfg.delays else 0.0
            category = classify_openai_error(e)
            logger.warning(
                "OpenAI transient error (attempt %d/%d, %s): %s: %s — retrying in %.1fs",
                attempt + 1,
                attempts,
                category.value,
                type(e).__name__,
                e,
                delay,
            )
            await asyncio.sleep(delay)

    # Unreachable: the final attempt either returns or re-raises above.
    raise RuntimeError("Retry exhausted with no exception captured")  # pragma: no cover


def call_with_retries(
    func: Callable[..., T],
    *args: Any,
    max_retries: int = 5,
    base_backoff: float = 1.5,
    jitter: float = 0.25,
    **kwargs: Any,
) -> T:
    """
    Sync retry with exponential backoff + jitter.

    Used for OpenAI file/vector-store operations (Responses API pattern).

    Calls ``func(*args, **kwargs)`` up to ``max_retries`` times in total.
    Values below 1 still make a single attempt. A failure is retried only
    when ``classify_openai_error`` maps it to a category in
    ``RETRYABLE_CATEGORIES``; otherwise it is re-raised immediately. After
    failed attempt ``n`` (1-based) the call sleeps
    ``base_backoff ** (n - 1)`` seconds plus a uniform random amount in
    ``[0, jitter]``.

    Note: ``max_retries``, ``base_backoff`` and ``jitter`` are consumed here
    and are never forwarded to ``func``.

    Returns:
        The return value of the first successful call.

    Raises:
        The exception from the final attempt, or the first non-retryable
        exception, re-raised with its original traceback.
    """
    # Always make at least one attempt. Previously max_retries <= 0 skipped
    # the call entirely and fell through to the "Unreachable" RuntimeError.
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            category = classify_openai_error(exc)
            retriable = category in RETRYABLE_CATEGORIES
            if attempt >= attempts or not retriable:
                raise
            delay = (base_backoff ** (attempt - 1)) + random.uniform(0, jitter)
            logger.warning(
                "Sync retry attempt %d/%d (%s): %s — sleeping %.1fs",
                attempt, attempts, category.value, exc, delay,
            )
            time.sleep(delay)

    # Truly unreachable now: attempts >= 1 and the last attempt returns or raises.
    raise RuntimeError("Unreachable")  # pragma: no cover
