"""Stage 4: Assign Papers to Clusters."""
from __future__ import annotations
import json
import logging
import pandas as pd
from .sysmsg import make_cluster_selection_system_message
from .schemas import build_cluster_selection_schema
from .constants import CLUSTER_SELECTION_OUTPUT_TOKENS
from crystallise.llm.client import batch_chat_completions
from crystallise.common.json_utils import parse_json_safe as _parse_json_safe

# Module-level logger named after this module; used for cluster-parsing warnings.
logger = logging.getLogger(__name__)


def make_paper_cluster_input(row: dict, cluster_list: list[dict], decision_type: str):
    """Build the (system message, prompt, output schema) triple for one paper.

    Args:
        row: One paper record (e.g. from ``DataFrame.to_dict("records")``); only
            its ``"rating_reasoning"`` field is used, as the prompt text.
        cluster_list: Candidate clusters, forwarded to the system-message builder.
        decision_type: Forwarded to the system-message builder.

    Returns:
        ``(system_message, formatted_text, output_schema)``. ``formatted_text`` is
        always a ``str``. A missing key, ``None``, or a pandas missing value
        (NaN / ``pd.NA``) becomes ``""``. Any other non-string value is passed
        through ``str()``.
    """
    output_schema = build_cluster_selection_schema()
    system_message = make_cluster_selection_system_message(cluster_list, decision_type)

    reasoning = row.get("rating_reasoning")
    if isinstance(reasoning, str):
        formatted_text = reasoning
    elif reasoning is None or (pd.api.types.is_scalar(reasoning) and pd.isna(reasoning)):
        # DataFrame.to_dict("records") yields NaN / pd.NA for empty cells, which
        # would otherwise be forwarded to the LLM client as the prompt.
        formatted_text = ""
    else:
        formatted_text = str(reasoning)

    return system_message, formatted_text, output_schema


def parse_clusters(x: dict | None, full_cluster_list: list[dict]) -> str | None:
    try:
        chosen = x.get("cluster_list") if isinstance(x, dict) else None
        if not isinstance(chosen, list):
            return None
        valid = [i for i in chosen if i in range(1, len(full_cluster_list) + 1)]
        if not valid:
            return None
        names = [full_cluster_list[i - 1]["cluster_name"] for i in valid]
        return json.dumps(names)
    except (KeyError, IndexError, TypeError) as e:
        logger.warning(f"Cluster parsing failed: {e}")
        return None


def get_paper_clusters(
    df: pd.DataFrame,
    model_name: str,
    cluster_list: list[dict],
    decision_type: str,
    on_progress=None,
    cancel_event=None,
) -> pd.DataFrame:
    """Ask the LLM to pick clusters for every paper and attach the result.

    The clusters are sorted by ``cluster_name``. The same sorted list is used to
    build each prompt and to resolve the indices in each response, so the two
    always agree. This assumes the system message presents clusters 1..N in list
    order; confirm in ``make_cluster_selection_system_message``.

    Args:
        df: Papers to classify, one prompt per row.
        model_name: Model identifier forwarded to ``batch_chat_completions``.
        cluster_list: Candidate clusters (dicts with a ``"cluster_name"`` key).
        decision_type: Prefix for the output column. It is also forwarded to the
            prompt builder.
        on_progress: Forwarded unchanged to ``batch_chat_completions``.
        cancel_event: Forwarded unchanged to ``batch_chat_completions``.

    Returns:
        A copy of ``df`` with a new ``"{decision_type}_cluster_list"`` column.
        Each cell is a JSON-encoded list of cluster names, or ``None`` where the
        response could not be parsed.
    """
    ordered_clusters = sorted(cluster_list, key=lambda c: c.get("cluster_name", ""))
    paper_rows = df.to_dict("records")

    sys_prompts: list = []
    user_prompts: list = []
    schemas: list = []
    for paper in paper_rows:
        sys_msg, prompt_text, schema = make_paper_cluster_input(
            paper, ordered_clusters, decision_type
        )
        sys_prompts.append(sys_msg)
        user_prompts.append(prompt_text)
        schemas.append(schema)

    raw_replies = batch_chat_completions(
        system_messages=sys_prompts,
        prompts=user_prompts,
        output_schemas=schemas,
        max_completion_tokens=CLUSTER_SELECTION_OUTPUT_TOKENS,
        model=model_name,
        on_progress=on_progress,
        cancel_event=cancel_event,
    )

    decoded = [_parse_json_safe(reply) for reply in raw_replies]
    assigned = [parse_clusters(payload, ordered_clusters) for payload in decoded]

    labelled = df.copy()
    labelled[f"{decision_type}_cluster_list"] = assigned
    return labelled
