"""API key authentication middleware for FastAPI (NetReady MVP)."""

from __future__ import annotations

import hmac

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

_PUBLIC_PREFIXES = (
    "/health",
    "/docs",
    "/redoc",
    "/openapi.json",
)


class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    """Validate an API key on every request except public paths and OPTIONS.

    The key is read from the ``X-API-Key`` header, falling back to an
    ``Authorization: Bearer <key>`` header. Accepted keys come from the
    comma-separated ``api_keys`` setting. When that setting yields no keys,
    any non-empty key is accepted (dev mode).

    On success, ``request.state.user`` and ``request.state.api_key`` are set
    before the request is forwarded. Otherwise a 401 JSON response is returned.
    """

    @staticmethod
    def _is_public(path: str) -> bool:
        """Return True if *path* (optionally prefixed with ``/v1``) is public.

        A public entry matches only itself or a ``/``-delimited sub-path of it.
        For example, ``/docs`` matches ``/docs`` and ``/docs/oauth2-redirect``.
        An unrelated route that merely shares the leading characters, such as
        ``/health-admin``, is NOT exempted from authentication.
        """
        normalized = path.removeprefix("/v1")
        return any(
            normalized == prefix or normalized.startswith(prefix + "/")
            for prefix in _PUBLIC_PREFIXES
        )

    async def dispatch(self, request: Request, call_next):
        """Forward authenticated or exempt requests; reject the rest with 401."""
        # Public paths — no auth required
        if self._is_public(request.url.path):
            return await call_next(request)

        # Allow OPTIONS (CORS preflight) without a key
        if request.method == "OPTIONS":
            return await call_next(request)

        # X-API-Key takes precedence. A blank or whitespace-only value falls
        # through to the Authorization: Bearer header.
        api_key = (request.headers.get("x-api-key") or "").strip() or _extract_bearer(
            request.headers.get("authorization", "")
        )

        if not api_key:
            return JSONResponse(
                status_code=401,
                content={"detail": "Missing API key. Use X-API-Key or Authorization: Bearer <key> header."},
            )

        # Imported lazily, as in the original module layout.
        from crystallise.config.settings import get_settings

        settings = get_settings()
        raw_keys = settings.api_keys or ""  # tolerate an unset value
        valid_keys = {k.strip() for k in raw_keys.split(",") if k.strip()}

        # No keys configured means dev mode: any non-empty key is accepted.
        if valid_keys:
            presented = api_key.encode("utf-8")
            # compare_digest runs in constant time, so the comparison does not
            # leak how much of a key matched. Encoding to bytes is required
            # because compare_digest rejects non-ASCII str arguments.
            authorised = any(
                hmac.compare_digest(presented, key.encode("utf-8"))
                for key in valid_keys
            )
            if not authorised:
                return JSONResponse(
                    status_code=401,
                    content={"detail": "Invalid API key."},
                )

        # Attach client identity to request state
        request.state.user = {"email": "api-client", "name": "API Client"}
        request.state.api_key = api_key
        return await call_next(request)


def _extract_bearer(header: str) -> str | None:
    """Extract token from 'Bearer <token>' header value."""
    if header.lower().startswith("bearer "):
        return header[7:].strip() or None
    return None
