"""Data formatting utilities for LLM prompts."""
import random
from .constants import INPUT_COLS


def format_paper_text(row: dict) -> str | None:
    """Format paper metadata into XML-tagged text for LLM input."""
    lower_input_cols = [x.lower() for x in INPUT_COLS]
    row = {k.lower(): v for k, v in row.items()}
    formatted_text = ""
    for col in lower_input_cols:
        value = row.get(col, "")
        if not isinstance(value, str):
            continue
        value = value.strip()
        if len(value) < 1:
            continue
        formatted_text += f"<{col}>\n{value}\n\n"
    if len(formatted_text) < 1:
        return None
    return formatted_text


def format_inc_exc(inc_exc: dict) -> str:
    output_str = ""
    shuffled_keys = random.sample(list(inc_exc.keys()), len(inc_exc.keys()))
    for key in shuffled_keys:
        inc_exc_dict = inc_exc[key]
        output_str += f"# {key}:\n"
        if "include" in inc_exc_dict:
            output_str += "\nInclude:\n"
            for item in inc_exc_dict["include"]:
                output_str += f"- {str(item)}\n"
        if "exclude" in inc_exc_dict:
            output_str += "\nExclude:\n"
            for item in inc_exc_dict["exclude"]:
                output_str += f"- {str(item)}\n"
        output_str += "\n\n"
    return output_str.strip()


def format_questions(question_list: list) -> str:
    if not question_list:
        return "(No research questions provided)"
    question_list = random.sample(list(question_list), len(question_list))
    return "\n\n".join([f"{i+1}. {x}" for i, x in enumerate(question_list)])


def format_clusters(cluster_list: list) -> str:
    """Render clusters as labelled, numbered text blocks for an LLM prompt.

    Each cluster produces::

        Cluster number:
        <1-based index>
        Cluster name:
        <cluster_name>
        Cluster description:
        <cluster_description>

    followed by four newlines (three blank lines). A missing
    ``cluster_name`` or ``cluster_description`` renders as an empty string.
    Non-string values are converted with ``str()``.

    Args:
        cluster_list: Sequence of dict-like clusters (each must support
            ``.get``).

    Returns:
        The concatenated text; ``""`` for an empty list.
    """
    def _as_text(value) -> str:
        # Strings are used as-is; anything else (None, numbers, ...) is stringified.
        return value if isinstance(value, str) else str(value)

    pieces = []
    for number, cluster in enumerate(cluster_list, start=1):
        pieces.append(f"Cluster number:\n{number}\n")
        pieces.append("Cluster name:\n")
        pieces.append(_as_text(cluster.get("cluster_name", "")))
        pieces.append("\nCluster description:\n")
        pieces.append(_as_text(cluster.get("cluster_description", "")))
        pieces.append("\n\n\n\n")
    return "".join(pieces)
