"""Stage 2: Score Explanation Generation (Reasoning)."""

from __future__ import annotations
import logging
import pandas as pd
from .sysmsg import make_rating_reasoning_system_message
from .formatting import format_paper_text
from .schemas import ReasoningWithEvidenceResponse
from .constants import REASONING_OUTPUT_TOKENS
from crystallise.llm.client import batch_chat_completions
from crystallise.common.json_utils import parse_json_safe as _parse_json_safe

logger = logging.getLogger(__name__)


def _parse_reasoning(x):
    """Return the ``rating_reasoning`` entry of a parsed LLM response.

    Yields ``None`` when ``x`` is not a dict or has no ``rating_reasoning`` key
    (e.g. the response was missing or failed to parse).
    """
    if not isinstance(x, dict):
        return None
    if "rating_reasoning" not in x:
        return None
    return x["rating_reasoning"]


def _parse_evidence(raw_evidence):
    """Keep only well-formed evidence spans from an LLM response.

    A span is kept when it is a dict whose ``text`` and ``section`` values are
    both truthy. Each kept span is normalised to exactly four keys:
    ``text``, ``section``, ``criterion`` (default ``""``) and ``supports``
    (default ``"include"``). Any non-list input yields an empty list.
    """
    if isinstance(raw_evidence, list):
        return [
            {
                "text": span["text"],
                "section": span["section"],
                "criterion": span.get("criterion", ""),
                "supports": span.get("supports", "include"),
            }
            for span in raw_evidence
            if isinstance(span, dict) and span.get("text") and span.get("section")
        ]
    return []


def make_rating_reasoning_input(row: dict, lit_rev_qs: list, inc_excl_dict: dict):
    """Build the (system message, prompt, response schema) triple for one paper.

    The prompt is the text produced by ``format_paper_text`` followed by an
    ``<Average Predicted Score>`` section containing ``str(row["mean_score"])``
    (an empty string when the key is absent).

    Returns:
        ``(system_message, prompt, ReasoningWithEvidenceResponse)``, or
        ``(None, None, None)`` when ``format_paper_text`` returns ``None``.
    """
    sys_msg = make_rating_reasoning_system_message(lit_rev_qs, inc_excl_dict)
    paper_text = format_paper_text(row)
    if paper_text is None:
        # Caller treats a None system message as "skip this row".
        return None, None, None
    score_section = "\n\n<Average Predicted Score>\n" + str(row.get("mean_score", ""))
    return sys_msg, paper_text + score_section, ReasoningWithEvidenceResponse


def get_rating_reasoning(
    df: pd.DataFrame,
    model_name: str,
    lit_rev_qs: list,
    inc_excl_dict: dict,
    on_progress=None,
    cancel_event=None,
) -> pd.DataFrame:
    """Generate an LLM explanation (plus evidence spans) for each paper's score.

    Each row of ``df`` is turned into a prompt by ``make_rating_reasoning_input``.
    Rows whose paper text cannot be formatted are skipped (and counted in a
    warning log); they get ``None`` reasoning and an empty evidence list. All
    remaining prompts are sent in a single ``batch_chat_completions`` call.
    When no row is sendable, the LLM is not called at all.

    Args:
        df: One row per paper; each row is passed to ``format_paper_text``.
        model_name: Model identifier forwarded to ``batch_chat_completions``.
        lit_rev_qs: Passed to ``make_rating_reasoning_system_message``.
        inc_excl_dict: Passed to ``make_rating_reasoning_system_message``.
        on_progress: Optional callback forwarded to ``batch_chat_completions``.
        cancel_event: Optional cancellation object forwarded to
            ``batch_chat_completions``.

    Returns:
        A copy of ``df`` with two added columns:
        ``rating_reasoning`` holds the parsed ``rating_reasoning`` value or ``None``.
        ``evidence`` holds a list of validated evidence dicts, possibly empty.
    """
    records = df.to_dict("records")
    inputs = [make_rating_reasoning_input(x, lit_rev_qs, inc_excl_dict) for x in records]

    # make_rating_reasoning_input returns (None, None, None) for rows whose
    # paper text could not be formatted; only the rest are sent to the LLM.
    valid_indices = [i for i, inp in enumerate(inputs) if inp[0] is not None]

    n_skipped = len(records) - len(valid_indices)
    if n_skipped:
        logger.warning(
            "Skipping reasoning for %d of %d papers with no formattable text",
            n_skipped,
            len(records),
        )

    responses = [None] * len(records)
    # Avoid an empty (pointless) LLM batch request when nothing is sendable.
    if valid_indices:
        valid_responses = batch_chat_completions(
            system_messages=[inputs[i][0] for i in valid_indices],
            prompts=[inputs[i][1] for i in valid_indices],
            output_schemas=[inputs[i][2] for i in valid_indices],
            max_completion_tokens=REASONING_OUTPUT_TOKENS,
            model=model_name,
            on_progress=on_progress,
            cancel_event=cancel_event,
        )
        # Scatter responses back to the row positions they were generated for.
        for idx, vi in enumerate(valid_indices):
            responses[vi] = valid_responses[idx]

    response_jsons = [_parse_json_safe(x) for x in responses]
    rating_reasonings = [_parse_reasoning(x) for x in response_jsons]
    evidence_lists = [
        _parse_evidence(x.get("evidence")) if isinstance(x, dict) else []
        for x in response_jsons
    ]

    result = df.copy()
    result["rating_reasoning"] = rating_reasonings
    result["evidence"] = evidence_lists
    return result
