"""Structured JSON logging for LLM API calls.

Provides a logger that emits LLM call metrics as JSON-encoded log messages,
context-local request correlation IDs, and a millisecond timer.
No external dependencies — uses stdlib logging + json.
"""

from __future__ import annotations

import json
import logging
import time
import uuid
from contextvars import ContextVar
from typing import Any


# Context-local request correlation ID; "" means no ID has been set yet.
_request_id: ContextVar[str] = ContextVar(
    "llm_request_id",
    default="",
)


def get_request_id() -> str:
    """Return the correlation ID bound to the current context.

    Returns an empty string when no ID has been set in this context.
    """
    current_id = _request_id.get()
    return current_id


def set_request_id(request_id: str) -> None:
    """Bind *request_id* as the correlation ID for the active context.

    The ``Token`` returned by ``ContextVar.set`` is discarded, so callers
    cannot restore the previous value through this helper.
    """
    _request_id.set(request_id)


def new_request_id() -> str:
    """Create a short correlation ID, bind it to the current context, return it.

    The ID is the first 8 hex digits of a random UUID4. It is kept short for
    log readability, so it is not guaranteed to be unique across very large
    volumes of requests.
    """
    short_id = uuid.uuid4().hex[:8]
    _request_id.set(short_id)
    return short_id


class LLMCallLogger:
    """Structured logger for LLM API call metrics.

    Each method builds a flat dict of standardized fields, serializes it with
    ``json.dumps`` and emits the JSON string as the message of a record on a
    standard-library logger. No handler or formatter is installed here; the
    application's logging configuration decides where records go.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
    """

    def __init__(self, logger_name: str = "crystallise.llm") -> None:
        self._logger = logging.getLogger(logger_name)

    def log_call(
        self,
        *,
        model: str,
        latency_ms: float,
        input_tokens: int = 0,
        output_tokens: int = 0,
        attempt: int = 1,
        error_category: str | None = None,
        request_id: str | None = None,
        service: str = "",
        extra: dict[str, Any] | None = None,
    ) -> None:
        """Log a structured ``llm_call`` event.

        The event is logged at ERROR when ``error_category`` is truthy,
        otherwise at INFO.

        Args:
            model: Model identifier.
            latency_ms: Call latency in milliseconds; rounded to 0.1 ms.
            input_tokens: Input token count.
            output_tokens: Output token count.
            attempt: Attempt number of this call.
            error_category: Error classification. When truthy it is added to
                the event and the record is logged at ERROR.
            request_id: Correlation ID; falls back to ``get_request_id()``
                when falsy.
            service: Name of the calling service.
            extra: Additional fields merged last, so its keys override the
                standard fields of the same name. Values that JSON cannot
                encode are logged via ``str()``.
        """
        level = logging.ERROR if error_category else logging.INFO
        # Building and serializing the payload is wasted work when the record
        # would be dropped anyway.
        if not self._logger.isEnabledFor(level):
            return

        entry: dict[str, Any] = {
            "event": "llm_call",
            "request_id": request_id or get_request_id(),
            "model": model,
            "service": service,
            "latency_ms": round(latency_ms, 1),
            "attempt": attempt,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }
        if error_category:
            entry["error_category"] = error_category
        if extra:
            entry.update(extra)

        # default=str is deliberate: ``extra`` may carry values json cannot
        # encode (datetime, Enum, exceptions, ...). A logging helper must not
        # raise TypeError into the caller's request path, so stringify them.
        self._logger.log(level, json.dumps(entry, default=str))

    def log_retry(
        self,
        *,
        model: str,
        attempt: int,
        max_retries: int,
        error_category: str,
        delay_s: float,
        request_id: str | None = None,
    ) -> None:
        """Log a structured ``llm_retry`` event at WARNING.

        Args:
            model: Model identifier.
            attempt: Attempt number that failed and triggered the retry.
            max_retries: Configured retry limit.
            error_category: Error classification of the failed attempt.
            delay_s: Delay before the next attempt, in seconds; rounded to
                0.01 s.
            request_id: Correlation ID; falls back to ``get_request_id()``
                when falsy.
        """
        if not self._logger.isEnabledFor(logging.WARNING):
            return

        entry: dict[str, Any] = {
            "event": "llm_retry",
            "request_id": request_id or get_request_id(),
            "model": model,
            "attempt": attempt,
            "max_retries": max_retries,
            "error_category": error_category,
            "delay_s": round(delay_s, 2),
        }
        # Same rationale as log_call: never let serialization raise.
        self._logger.log(logging.WARNING, json.dumps(entry, default=str))


# Shared module-wide instance, bound to the "crystallise.llm" logger.
llm_logger = LLMCallLogger(logger_name="crystallise.llm")


class Timer:
    """Context manager that records elapsed wall-clock time in milliseconds.

    On entry, a ``time.perf_counter()`` reading is stored in ``start_time``.
    On exit, the time since that reading, in milliseconds, is stored in
    ``elapsed_ms``. Both attributes start at 0, and ``elapsed_ms`` is only
    updated on exit. Exceptions raised inside the ``with`` body are not
    suppressed.
    """

    def __init__(self):
        self.start_time: float = 0
        self.elapsed_ms: float = 0

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        finished = time.perf_counter()
        self.elapsed_ms = 1000 * (finished - self.start_time)
