"""
Unified OpenAI client wrapper supporting both Chat Completions and Responses API.

Provides async and sync interfaces with:
- Semaphore-based concurrency control
- Fresh AsyncOpenAI client per request (avoids connection pool exhaustion)
- Structured output via Pydantic schemas
- Progress callbacks and cancellation support
"""

from __future__ import annotations

import asyncio
import logging
import os
from typing import Any, Callable, Optional, Type

from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel

from .retry import RetryConfig, async_retry_with_backoff

logger = logging.getLogger(__name__)


class AuthError(Exception):
    """
    OpenAI rejected the credentials used for a request.

    Intended for HTTP 401/403 responses, i.e. an API key that is invalid or
    not authorized for the requested operation.
    """


async def async_chat_completion(
    *,
    system_message: str,
    prompt: str,
    output_schema: Type[BaseModel] | None = None,
    max_completion_tokens: int = 4096,
    model: str = "gpt-5-nano",
    api_key: str | None = None,
    temperature: float | None = None,
    retry_config: RetryConfig | None = None,
) -> str | None:
    """
    Send a single async Chat Completions request, optionally with structured output.

    A fresh ``AsyncOpenAI`` client is opened for the request and closed when it
    finishes. The request is wrapped in ``async_retry_with_backoff``.

    Args:
        system_message: Content of the system message.
        prompt: Content of the user message.
        output_schema: Pydantic model for structured output. When given, the
            request goes through ``chat.completions.parse`` instead of ``create``.
        max_completion_tokens: Upper bound on generated tokens.
        model: Model name. It is checked by ``validate_request`` before any
            network call.
        api_key: Explicit API key. Falls back to ``OPENAI_API_KEY``, then
            ``CRYSTALLISE_OPENAI_API_KEY``.
        temperature: Sampling temperature. Omitted from the request when
            ``should_strip_temperature(model)`` is true.
        retry_config: Retry policy. Defaults to ``RetryConfig()``.

    Returns:
        The first choice's message content. Returns None when:
        - no API key is configured;
        - the response was truncated and carried no content; or
        - any other non-auth error occurred (the error is logged).

    Raises:
        AuthError: OpenAI rejected the key, either with 401
            (``AuthenticationError``) or 403 (``PermissionDeniedError``).
        Whatever ``validate_request`` raises propagates, because it runs
        before the error-handling ``try``.
    """
    import openai

    key = api_key or os.environ.get("OPENAI_API_KEY") or os.environ.get("CRYSTALLISE_OPENAI_API_KEY")
    if not key:
        logger.error("OPENAI_API_KEY not set")
        return None

    cfg = retry_config or RetryConfig()

    # Preflight capability validation (deliberately outside the try below so
    # unsupported requests fail loudly instead of returning None).
    from crystallise.config.model_capabilities import validate_request, should_strip_temperature

    validate_request(
        model,
        max_output_tokens=max_completion_tokens,
        needs_structured_output=output_schema is not None,
    )

    async def _call() -> str | None:
        # A new client per attempt; the context manager closes its HTTP pool.
        async with AsyncOpenAI(api_key=key) as client:
            kwargs: dict[str, Any] = {
                "model": model,
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": prompt},
                ],
                "max_completion_tokens": max_completion_tokens,
                "top_p": 1,
                "store": False,
            }
            if output_schema is not None:
                kwargs["response_format"] = output_schema
            if temperature is not None and not should_strip_temperature(model):
                kwargs["temperature"] = temperature

            if output_schema is not None:
                response = await client.chat.completions.parse(**kwargs)
            else:
                response = await client.chat.completions.create(**kwargs)

            return response.choices[0].message.content

    from .logging import llm_logger, Timer

    try:
        with Timer() as timer:
            result = await async_retry_with_backoff(_call, cfg)
        llm_logger.log_call(model=model, latency_ms=timer.elapsed_ms, service="chat_completion")
        return result
    except (openai.AuthenticationError, openai.PermissionDeniedError) as e:
        # 401 and 403 both mean the key is unusable; surface them to the caller
        # instead of silently returning None like other failures.
        raise AuthError(f"OpenAI authentication failed: {e}. Check your API key.") from e
    except openai.LengthFinishReasonError as e:
        # Best effort: salvage any partial content attached to the truncated
        # completion. Only the shape-related failures are expected here.
        try:
            content = e.completion.choices[0].message.content
            if content:
                return content
        except (AttributeError, IndexError, TypeError):
            pass
        logger.warning("LengthFinishReasonError: response truncated")
        return None
    except Exception as e:
        from .errors import classify_openai_error

        cat = classify_openai_error(e)
        llm_logger.log_call(model=model, latency_ms=0, error_category=cat.value, service="chat_completion")
        logger.exception("OpenAI async error [%s]", cat.value)
        return None


async def async_batch_chat_completions(
    *,
    system_messages: list[str],
    prompts: list[str],
    output_schemas: list[Type[BaseModel] | None],
    max_completion_tokens: int = 4096,
    model: str = "gpt-5-nano",
    max_concurrent: int = 10,
    on_progress: Callable[[int, int], None] | None = None,
    cancel_event: Optional[Any] = None,
    api_key: str | None = None,
    retry_config: RetryConfig | None = None,
    temperature: float | None = None,
) -> list[str | None]:
    """
    Send multiple Chat Completions requests with semaphore-based concurrency.

    Request ``i`` is built from ``system_messages[i]``, ``prompts[i]`` and
    ``output_schemas[i]``. Results come back in the same order as the inputs.

    Args:
        system_messages: One system message per request.
        prompts: One user prompt per request.
        output_schemas: One Pydantic schema (or None) per request.
        max_completion_tokens: Forwarded to ``async_chat_completion``.
        model: Forwarded to ``async_chat_completion``.
        max_concurrent: Maximum number of in-flight requests. Must be >= 1.
        on_progress: ``callback(completed, total)``, called after each request
            that actually ran. Requests skipped due to cancellation are not
            counted.
        cancel_event: Object with ``is_set()``, e.g. ``threading.Event``. Once
            set, requests that have not started yet return None.
        api_key: Forwarded to ``async_chat_completion``.
        retry_config: Forwarded to ``async_chat_completion``.
        temperature: Forwarded to ``async_chat_completion``.

    Returns:
        One result per input (content string or None).

    Raises:
        ValueError: The three input lists differ in length, or
            ``max_concurrent < 1``.
        AuthError: Propagated from ``async_chat_completion`` via ``gather``.
    """
    total = len(prompts)
    # zip() would silently drop trailing items and misalign results with inputs.
    if len(system_messages) != total or len(output_schemas) != total:
        raise ValueError(
            "system_messages, prompts and output_schemas must have equal lengths "
            f"(got {len(system_messages)}, {total}, {len(output_schemas)})"
        )
    # Semaphore(0) would block every task forever.
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be >= 1 (got {max_concurrent})")

    semaphore = asyncio.Semaphore(max_concurrent)
    completed_count = 0

    def _cancelled() -> bool:
        return cancel_event is not None and cancel_event.is_set()

    async def _limited(sys_msg: str, prompt: str, schema: Type[BaseModel] | None) -> str | None:
        nonlocal completed_count
        if _cancelled():
            return None
        async with semaphore:
            # Re-check: cancellation may have happened while waiting for a slot.
            if _cancelled():
                return None
            result = await async_chat_completion(
                system_message=sys_msg,
                prompt=prompt,
                output_schema=schema,
                max_completion_tokens=max_completion_tokens,
                model=model,
                api_key=api_key,
                temperature=temperature,
                retry_config=retry_config,
            )
            # Safe without a lock: all tasks run on the same event loop.
            completed_count += 1
            if on_progress is not None:
                on_progress(completed_count, total)
            return result

    tasks = [_limited(sm, p, schema) for sm, p, schema in zip(system_messages, prompts, output_schemas)]
    return list(await asyncio.gather(*tasks))


def batch_chat_completions(
    *,
    system_messages: list[str],
    prompts: list[str],
    output_schemas: list[Type[BaseModel] | None],
    max_completion_tokens: int = 4096,
    model: str = "gpt-5-nano",
    max_concurrent: int = 10,
    on_progress: Callable[[int, int], None] | None = None,
    cancel_event: Optional[Any] = None,
    api_key: str | None = None,
    retry_config: RetryConfig | None = None,
) -> list[str | None]:
    """
    Sync wrapper around ``async_batch_chat_completions``.

    Runs the batch on a new event loop via ``asyncio.run``. This wrapper
    therefore cannot be called from code already running inside an event
    loop: ``asyncio.run`` raises ``RuntimeError`` there. Async callers should
    await ``async_batch_chat_completions`` directly.

    All arguments are forwarded unchanged, including ``retry_config``, so the
    sync and async entry points accept the same retry policy.

    Returns:
        One result per input (content string or None), in input order.
    """
    return asyncio.run(
        async_batch_chat_completions(
            system_messages=system_messages,
            prompts=prompts,
            output_schemas=output_schemas,
            max_completion_tokens=max_completion_tokens,
            model=model,
            max_concurrent=max_concurrent,
            on_progress=on_progress,
            cancel_event=cancel_event,
            api_key=api_key,
            retry_config=retry_config,
        )
    )


def responses_api_call(
    *,
    client: OpenAI,
    payload: dict[str, Any],
) -> Any:
    """
    Execute an OpenAI Responses API call and return the raw response object.

    The payload should be built by the extraction/verification builders. It is
    splatted into ``client.responses.create`` through ``call_with_retries``.

    On success, the latency is logged to ``llm_logger`` under
    ``service="responses_api"``. On failure, the exception is classified with
    ``classify_openai_error``, logged with its ``error_category``, and then
    re-raised unchanged. This matches how ``async_chat_completion`` reports
    errors, but this function does not swallow them.

    Raises:
        Whatever ``call_with_retries`` raises once its retries are exhausted.
    """
    from .retry import call_with_retries
    from .logging import llm_logger, Timer

    model = payload.get("model", "unknown")
    try:
        with Timer() as timer:
            result = call_with_retries(client.responses.create, **payload)
    except Exception as e:
        from .errors import classify_openai_error

        cat = classify_openai_error(e)
        llm_logger.log_call(model=model, latency_ms=0, error_category=cat.value, service="responses_api")
        raise
    llm_logger.log_call(model=model, latency_ms=timer.elapsed_ms, service="responses_api")
    return result
