"""Tag grouping service for the AutoIndexer."""

from __future__ import annotations

import json
import logging
from typing import Any

from openai import OpenAI

from crystallise.llm.cost import tally_usage
from crystallise.openai_resources.vector_stores import normalize_chat_completion_kwargs
from crystallise.prompts.indexer import GROUPING_SYSTEM_PROMPT as SYSTEM_PROMPT

logger: logging.Logger = logging.getLogger(__name__)

# Chat model used for tag grouping; the same name is handed to tally_usage
# so cost accounting matches the model actually called.
GROUPING_MODEL: str = "gpt-4.1"


def _build_user_prompt(req: Any) -> str:
    """Render the user message: field name, values to group, optional context and hint."""
    user_parts = [
        f"Field: {req.field_name}",
        f"Values to group ({len(req.values)} unique):",
        # ensure_ascii=False so the model sees (and can echo back) the literal
        # characters of non-ASCII tag values instead of \uXXXX escapes.
        json.dumps(req.values, indent=2, ensure_ascii=False),
    ]

    context = req.project_context
    if context:
        if context.description:
            user_parts.append(f"\nProject context: {context.description}")
        if context.research_questions:
            user_parts.append(f"Research questions: {'; '.join(context.research_questions)}")

    if req.num_groups_hint:
        user_parts.append(f"\nAim for approximately {req.num_groups_hint} groups.")

    return "\n".join(user_parts)


def _clean_group_values(raw_values: Any) -> list[str]:
    """Normalise an LLM-supplied ``values`` entry into a list of non-blank strings."""
    if isinstance(raw_values, str):
        # A bare string would otherwise be iterated character by character.
        raw_values = [raw_values]
    elif not isinstance(raw_values, list):
        return []
    # Drop None and blank entries, but keep legitimate falsy values such as 0.
    as_text = (str(v) for v in raw_values if v is not None)
    return [text for text in as_text if text.strip()]


def group_tags(*, client: OpenAI, req: Any) -> Any:
    """Ask the LLM to cluster a field's extracted tag values into named groups.

    Args:
        client: OpenAI client used for the chat completion call.
        req: Grouping request exposing ``field_name``, ``values``,
            ``project_context`` (optional; read for ``description`` and
            ``research_questions``) and ``num_groups_hint``.

    Returns:
        A ``GroupTagsResponse`` with the parsed ``TagGroup`` list and the
        tallied token usage (``None`` when the API reports no usage). If the
        model output cannot be parsed or has an unexpected shape, ``groups``
        is empty rather than an exception being raised.
    """
    # Imported lazily (presumably to avoid import cycles). Both are resolved
    # before the API call so an import failure doesn't waste a paid request.
    from api.schemas.indexer import GroupTagsResponse, TagGroup
    from crystallise.common.json_utils import LLMParseError, parse_llm_json

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": _build_user_prompt(req)},
    ]

    completion_kwargs = normalize_chat_completion_kwargs(
        {
            "model": GROUPING_MODEL,
            "messages": messages,
            "max_completion_tokens": 4096,
            "temperature": 0.3,
        }
    )
    resp = client.chat.completions.create(**completion_kwargs)
    content = resp.choices[0].message.content or "{}"
    usage_data = tally_usage(GROUPING_MODEL, [resp.usage]) if resp.usage else None

    try:
        data = parse_llm_json(content)
    except LLMParseError:
        logger.warning("Failed to parse grouping response: %s", content[:200])
        return GroupTagsResponse(groups=[], usage=usage_data)

    # Accept either {"groups": [...]} or a bare top-level list of groups.
    groups_data = data.get("groups", []) if isinstance(data, dict) else data
    if not isinstance(groups_data, list):
        logger.warning(
            "Unexpected grouping response shape (%s): %s",
            type(groups_data).__name__,
            content[:200],
        )
        return GroupTagsResponse(groups=[], usage=usage_data)

    groups = []
    for item in groups_data:
        if not isinstance(item, dict):
            continue
        raw_name = item.get("name")
        name = str(raw_name).strip() if raw_name is not None else ""
        # Clean values *before* the emptiness check so a group containing
        # only null/blank entries is skipped instead of emitted with no values.
        values = _clean_group_values(item.get("values"))
        if not name or not values:
            continue
        # `or ""` also covers an explicit JSON null rationale.
        rationale = item.get("rationale") or ""
        groups.append(
            TagGroup(
                name=name,
                values=values,
                rationale=str(rationale),
            )
        )

    return GroupTagsResponse(groups=groups, usage=usage_data)
