"""Stage 1: Initial Paper Screening (Labelling)."""
from __future__ import annotations
import logging
import numpy as np
import pandas as pd
from .sysmsg import make_labelling_system_message
from .formatting import format_paper_text
from .schemas import IncludeResponse
from .constants import LABELLING_OUTPUT_TOKENS, NUM_LABELLING_REPETITIONS
from crystallise.llm.client import batch_chat_completions
from crystallise.common.json_utils import parse_json_safe as _parse_json_safe

logger = logging.getLogger(__name__)


def _parse_include_rating(x):
    return x["include_rating"] if isinstance(x, dict) and "include_rating" in x else None


def make_paper_labelling_input(row: dict, lit_rev_qs: list, inc_excl_dict: dict):
    """Assemble the three pieces needed to request a label for one paper.

    Returns a ``(system_message, formatted_text, IncludeResponse)`` tuple:
    the system message built from the review questions and the
    inclusion/exclusion criteria, the text produced by ``format_paper_text``
    for ``row``, and the response schema class.
    """
    # Tuple elements are evaluated left to right: system message first, then the paper text.
    return (
        make_labelling_system_message(lit_rev_qs, inc_excl_dict),
        format_paper_text(row),
        IncludeResponse,
    )


def get_labelling_scores(
    df: pd.DataFrame,
    model_name: str,
    lit_rev_qs: list,
    inc_excl_dict: dict,
    num_repetitions: int = NUM_LABELLING_REPETITIONS,
    on_progress=None,
    cancel_event=None,
) -> pd.DataFrame:
    """Score every paper in ``df`` repeatedly and summarise the ratings.

    Each row is turned into a labelling request via
    ``make_paper_labelling_input``. Rows whose formatted text is None are
    never sent to the model and end up with an empty ``score_list``.

    Args:
        df: One paper per row; rows are passed to the formatter as dicts.
        model_name: Forwarded to ``batch_chat_completions`` as ``model``.
        lit_rev_qs: Review questions used to build the system message.
        inc_excl_dict: Inclusion/exclusion criteria used to build the
            system message.
        num_repetitions: Number of scoring passes over the valid rows.
        on_progress: Optional callable ``(completed, total)`` reporting
            progress across all repetitions combined.
        cancel_event: Optional object exposing ``is_set()``; checked before
            each repetition and also forwarded to the batch call.

    Returns:
        A copy of ``df`` with four added columns: ``score_list`` (the
        ratings collected for the row), ``min_score``, ``max_score`` and
        ``mean_score``. The three summary values are None for rows with no
        collected rating.
    """
    records = df.to_dict("records")
    all_inputs = [
        make_paper_labelling_input(rec, lit_rev_qs, inc_excl_dict) for rec in records
    ]

    # Only rows with usable formatted text are sent to the model.
    valid_indices = [i for i, inp in enumerate(all_inputs) if inp[1] is not None]
    num_valid = len(valid_indices)

    # The request payload is the same for every repetition, so build it once.
    valid_sys = [all_inputs[i][0] for i in valid_indices]
    valid_prompts = [all_inputs[i][1] for i in valid_indices]
    valid_schemas = [all_inputs[i][2] for i in valid_indices]

    score_list = [[] for _ in records]

    # With nothing to score there is no point in calling the model at all.
    repetitions = range(num_repetitions) if valid_indices else range(0)

    for rep_idx in repetitions:
        if cancel_event is not None and cancel_event.is_set():
            break

        # Bind rep_idx as a default so the callback keeps this repetition's
        # offset even if it is invoked after the loop variable has moved on.
        def rep_progress(completed, total, _rep=rep_idx):
            if on_progress is not None:
                on_progress(_rep * num_valid + completed, num_repetitions * num_valid)

        valid_responses = batch_chat_completions(
            system_messages=valid_sys,
            prompts=valid_prompts,
            output_schemas=valid_schemas,
            max_completion_tokens=LABELLING_OUTPUT_TOKENS,
            model=model_name,
            on_progress=rep_progress,
            cancel_event=cancel_event,
        )

        # zip pairs each response with its original row and tolerates a
        # shorter-than-expected response list instead of raising IndexError.
        for row_idx, response in zip(valid_indices, valid_responses):
            rating = _parse_include_rating(_parse_json_safe(response))
            if rating is not None:
                score_list[row_idx].append(rating)

    result = df.copy()
    result["score_list"] = score_list
    result["min_score"] = [min(s) if s else None for s in score_list]
    result["max_score"] = [max(s) if s else None for s in score_list]
    result["mean_score"] = [float(np.mean(s)) if s else None for s in score_list]
    return result
