"""
Token tallying and cost estimation.

Merged from:
- CrystalliseAppsToolbox apps/common/batch.py (estimate_cost, tally_usage, DEFAULT_PRICING_PER_1M)
- AI_Screening_UI metrics tracking patterns
"""

from __future__ import annotations

from typing import Any

# Default pricing (USD per 1M tokens).
# Each model name maps to its per-1M-token rates under the keys "input" and
# "output", plus an optional "cached_input" rate (absent for the gpt-5.4-*
# entries below).
# NOTE(review): these are hard-coded snapshots of provider pricing — verify
# against the provider's current published price list before trusting the
# resulting estimates.
DEFAULT_PRICING_PER_1M: dict[str, dict[str, float]] = {
    "gpt-5-nano": {"input": 0.05, "cached_input": 0.01, "output": 0.40},
    "gpt-5-mini": {"input": 0.20, "cached_input": 0.025, "output": 2.00},
    "gpt-5.4-nano": {"input": 0.20, "output": 1.25},
    "gpt-5.4-mini": {"input": 0.75, "output": 4.50},
    "gpt-4.1": {"input": 2.00, "cached_input": 0.50, "output": 8.00},
}


def estimate_cost(
    model: str,
    input_tokens: int,
    output_tokens: int,
    pricing_table: dict[str, dict[str, float]] | None = None,
) -> float:
    """Estimate cost in USD for a given model and token counts."""
    table = pricing_table or DEFAULT_PRICING_PER_1M
    p = table.get(model, {"input": 0.0, "output": 0.0})
    return (input_tokens / 1_000_000.0) * p.get("input", 0.0) + (output_tokens / 1_000_000.0) * p.get("output", 0.0)


def _usage_field(usage: Any, *names: str) -> Any:
    """Return the first non-None value among *names* on a usage dict or object, else None."""
    for name in names:
        value = usage.get(name) if isinstance(usage, dict) else getattr(usage, name, None)
        if value is not None:
            return value
    return None


def tally_usage(
    model: str,
    usage_list: list[dict[str, Any] | Any],
    pricing_table: dict[str, dict[str, float]] | None = None,
) -> dict[str, Any]:
    """
    Aggregate a list of usage blobs (dicts or SDK objects) into totals + estimated cost.

    Each blob is read either by key (for ``dict``) or by attribute (anything
    else). Input tokens come from ``input_tokens``, falling back to
    ``prompt_tokens``. Output tokens come from ``output_tokens``, falling back
    to ``completion_tokens``, so both Responses-style and Chat-Completions-style
    usage are counted. Missing or ``None`` counts are treated as 0. A missing
    or ``None`` ``total_tokens`` falls back to input + output for that blob.
    A ``None`` *usage_list* is treated as empty.

    Args:
        model: Model name used to look up pricing via ``estimate_cost``.
        usage_list: Usage blobs to aggregate.
        pricing_table: Optional pricing override forwarded to ``estimate_cost``.

    Returns dict with keys: input_tokens, output_tokens, total_tokens, estimated_cost_usd.
    """
    total_in = 0
    total_out = 0
    total = 0
    for usage in usage_list or ():
        # ``or 0`` turns an absent/None count into 0 instead of crashing in int().
        in_tok = int(_usage_field(usage, "input_tokens", "prompt_tokens") or 0)
        out_tok = int(_usage_field(usage, "output_tokens", "completion_tokens") or 0)
        reported_total = _usage_field(usage, "total_tokens")

        total_in += in_tok
        total_out += out_tok
        total += in_tok + out_tok if reported_total is None else int(reported_total)

    cost = estimate_cost(model, total_in, total_out, pricing_table)
    return {
        "input_tokens": total_in,
        "output_tokens": total_out,
        "total_tokens": total,
        "estimated_cost_usd": cost,
    }
