"""
AI Service for Criteria generation, PICO extraction, refinement, and consolidation.

Ported from AI_Screening_UI criteria_workspace/ai_service.py.
Stripped of Panel/DB/diagnostics coupling. Uses crystallise.llm.client for LLM calls.
All configuration (model, api_key) is passed as explicit parameters.
"""

from __future__ import annotations

import json
import logging
import re
from typing import Optional

from .models import (
    ConsolidationProposal,
    ConsolidationResult,
    CriterionSource,
    CriterionType,
    DuplicateGroup,
    ExclusionCriterion,
    PICO_CATEGORIES,
)
from .prompts import (
    conflict_refinement_system_prompt,
    conflict_refinement_user_prompt,
    consolidation_system_prompt,
    consolidation_user_prompt,
    context_refinement_system_prompt,
    context_refinement_user_prompt,
    exclusion_generation_system_prompt,
    exclusion_generation_user_prompt,
    inclusion_generation_system_prompt,
    inclusion_generation_user_prompt,
    pico_extraction_system_prompt,
    pico_extraction_user_prompt,
    question_analysis_system_prompt,
    question_analysis_user_prompt,
    refinement_user_prompt,
)

logger: logging.Logger = logging.getLogger(__name__)

# Model used by default for criteria-level work, which needs stronger reasoning
# than per-paper screening does.
DEFAULT_CRITERIA_MODEL: str = "gpt-4.1"


# ---------------------------------------------------------------------------
# Response parsing helpers
# ---------------------------------------------------------------------------


# Markdown decorations an LLM may leave around a criterion's text: leading or
# trailing emphasis asterisks, a bullet dash, or heading hashes.
_CRITERION_TEXT_NOISE = (
    re.compile(r"^\*+\s*"),
    re.compile(r"\s*\*+$"),
    re.compile(r"^-\s*"),
    re.compile(r"^#+\s*"),
)


def _clean_criterion_text(raw: object) -> str:
    """Return *raw* with markdown bullet/emphasis/heading markers stripped.

    Non-string input (e.g. a JSON ``null``) yields ``""``. The patterns are
    re-applied until nothing changes, so combinations such as ``"- **text**"``
    are fully cleaned rather than leaving ``"**text"`` behind.
    """
    if not isinstance(raw, str):
        return ""
    text = raw.strip()
    previous = None
    while text != previous:
        previous = text
        for pattern in _CRITERION_TEXT_NOISE:
            text = pattern.sub("", text)
        text = text.strip()
    return text


def _coerce_confidence(raw: object, default: float = 0.8) -> float:
    """Convert an LLM-supplied confidence into a float clamped to [0, 1].

    Missing (``None``) or non-numeric values (e.g. ``"high"``) fall back to
    *default* instead of raising and aborting the whole parse.
    """
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return default
    return min(1.0, max(0.0, value))


def _extract_criterion_items(response: str) -> Optional[list]:
    """Decode *response* and return its list of raw criterion entries.

    Accepted shapes (after stripping markdown code fences, and falling back to
    the outermost ``{...}``/``[...]`` span when prose surrounds the JSON):

    * a bare JSON array;
    * an object with a ``"criteria"`` key;
    * a single criterion object (has a ``"text"`` key);
    * any other object: list-valued entries are flattened and other values
      kept as individual candidates (e.g. ``{"exclusion_criteria": [...]}`` or
      ``{"c1": {...}, "c2": {...}}``).

    Returns:
        The list of entries (not yet validated), or ``None`` after logging a
        warning when nothing usable could be decoded.
    """
    text = response.strip()
    if text.startswith("```"):
        text = "\n".join(line for line in text.split("\n") if not line.strip().startswith("```"))

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
        if match is None:
            logger.warning("No JSON found in criteria response")
            return None
        try:
            data = json.loads(match.group(1))
        except json.JSONDecodeError:
            logger.warning("Failed to parse criteria response as JSON")
            return None

    if isinstance(data, dict):
        if "criteria" in data:
            data = data["criteria"]
        elif "text" in data:
            # A single criterion object rather than a collection.
            data = [data]
        else:
            flattened: list = []
            for value in data.values():
                if isinstance(value, list):
                    flattened.extend(value)
                else:
                    flattened.append(value)
            data = flattened

    if not isinstance(data, list):
        logger.warning("Response is not a JSON array or object with criteria key")
        return None
    return data


def _parse_criteria_response(
    response: str,
    project_id: int,
    existing_criteria: Optional[list[ExclusionCriterion]] = None,
) -> list[ExclusionCriterion]:
    """Parse an LLM reply into new, inactive ``ExclusionCriterion`` objects.

    Each decoded entry is validated and normalised:

    * non-dict entries are skipped;
    * categories not in ``PICO_CATEGORIES`` (or not strings) become ``"Other"``;
    * markdown markers are stripped from ``text``; texts shorter than 3
      characters, or matching (case-insensitively) an existing criterion or
      one accepted earlier in this reply, are dropped;
    * ``criterion_type`` is matched case-insensitively against
      ``include``/``exclude`` and defaults to ``exclude``;
    * ``confidence`` is clamped to [0, 1], defaulting to 0.8 when missing or
      non-numeric;
    * ``title_abstract_assessable`` defaults to True; non-bool values are read
      as true only for ``"true"``, ``"1"`` or ``"yes"``;
    * ``research_question_links`` must be a list, otherwise ``[]``.

    Args:
        response: Raw LLM reply text.
        project_id: Project id stamped onto every created criterion.
        existing_criteria: Criteria whose ``.text`` must not be duplicated.

    Returns:
        Created criteria with ``id=0``, ``is_active=False`` and
        ``source=CriterionSource.ai_generated``. Empty when no JSON can be
        recovered from *response*.
    """
    items = _extract_criterion_items(response)
    if items is None:
        return []

    valid_categories = set(PICO_CATEGORIES)
    # Lower-cased texts already taken; grows as entries are accepted so the
    # reply cannot duplicate itself either.
    seen_texts = {c.text.lower() for c in (existing_criteria or []) if isinstance(c.text, str)}

    criteria: list[ExclusionCriterion] = []
    for item in items:
        if not isinstance(item, dict):
            continue

        text_val = _clean_criterion_text(item.get("text"))
        if len(text_val) < 3 or text_val.lower() in seen_texts:
            continue

        category = item.get("category", "Other")
        # The isinstance check also keeps unhashable values (lists/dicts) away
        # from the set lookup, which would otherwise raise TypeError.
        if not isinstance(category, str) or category not in valid_categories:
            logger.debug("Invalid category '%s' mapped to 'Other'", category)
            category = "Other"

        ctype_raw = item.get("criterion_type", "exclude")
        ctype_str = ctype_raw.strip().lower() if isinstance(ctype_raw, str) else ""
        if ctype_str not in ("include", "exclude"):
            ctype_str = "exclude"

        ta_assessable = item.get("title_abstract_assessable", True)
        if not isinstance(ta_assessable, bool):
            ta_assessable = str(ta_assessable).lower() in ("true", "1", "yes")

        rq_links = item.get("research_question_links", [])
        if not isinstance(rq_links, list):
            rq_links = []

        criteria.append(
            ExclusionCriterion(
                id=0,
                project_id=project_id,
                category=category,
                text=text_val,
                description=item.get("description") or "",
                source=CriterionSource.ai_generated,
                is_active=False,
                criterion_type=CriterionType(ctype_str),
                ai_confidence=_coerce_confidence(item.get("confidence")),
                ai_rationale=item.get("rationale") or "",
                title_abstract_assessable=ta_assessable,
                research_question_links=rq_links,
            )
        )
        seen_texts.add(text_val.lower())

    logger.info("Parsed %d criteria from LLM response", len(criteria))
    return criteria


# Line prefixes of the plain-text suggestion format (matched case-insensitively)
# and the suggestion key each one fills.
_TEXT_SUGGESTION_FIELDS = (
    ("- action:", "action"),
    ("- category:", "category"),
    ("- suggested:", "text"),
    ("- rationale:", "rationale"),
)


def _decode_refinement_json(text: str):
    """Return the JSON value encoded by *text*, or ``None`` if it is not valid JSON."""
    try:
        return json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None


def _as_suggestion_list(data) -> Optional[list[dict]]:
    """Normalise decoded JSON into a list of suggestion dicts, or ``None`` if unusable.

    Accepts a bare array, an object wrapping the array under ``"suggestions"``,
    a single suggestion object (has ``"action"``), or an object whose only
    list-valued entry is the array. Non-dict array elements are dropped.
    """
    if isinstance(data, dict):
        if isinstance(data.get("suggestions"), list):
            data = data["suggestions"]
        elif "action" in data:
            return [data]
        else:
            list_values = [value for value in data.values() if isinstance(value, list)]
            if len(list_values) != 1:
                return None
            data = list_values[0]
    if not isinstance(data, list):
        return None
    return [item for item in data if isinstance(item, dict)]


def _parse_text_suggestions(content: str) -> list[dict]:
    """Parse blank-line-separated ``- Field: value`` records into suggestion dicts."""
    suggestions: list[dict] = []
    current: dict = {}
    for raw_line in content.split("\n"):
        line = raw_line.strip()
        if not line:
            # A blank line closes the record being built.
            if current:
                suggestions.append(current)
                current = {}
            continue
        lower = line.lower()
        for prefix, key in _TEXT_SUGGESTION_FIELDS:
            if lower.startswith(prefix):
                current[key] = line.split(":", 1)[1].strip()
                break
    if current:
        suggestions.append(current)
    return suggestions


def _parse_refinement_response(content: str) -> list[dict]:
    """Parse an LLM refinement reply into structured suggestion dicts.

    Strategies, tried in order:

    1. the whole reply (markdown code fences removed) decoded as JSON;
    2. the outermost ``{...}``/``[...]`` span when prose surrounds the JSON;
       only used if it yields at least one suggestion;
    3. a plain-text format of blank-line-separated records with lines
       ``- Action: ...``, ``- Category: ...``, ``- Suggested: ...`` (stored as
       ``text``) and ``- Rationale: ...``.

    Returns:
        List of suggestion dicts; JSON objects are passed through unchanged.
        Empty when nothing usable is found.
    """
    text = content.strip()
    if text.startswith("```"):
        text = "\n".join(line for line in text.split("\n") if not line.strip().startswith("```"))

    suggestions = _as_suggestion_list(_decode_refinement_json(text))
    if suggestions is not None:
        return suggestions

    # Require a non-empty result here so stray brackets inside a plain-text
    # reply cannot mask the text-format fallback.
    match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
    if match:
        embedded = _as_suggestion_list(_decode_refinement_json(match.group(1)))
        if embedded:
            return embedded

    return _parse_text_suggestions(content)


# ---------------------------------------------------------------------------
# Core AI Service
# ---------------------------------------------------------------------------


class CriteriaAIService:
    """
    AI operations for eligibility criteria.

    All methods accept model/api_key as explicit parameters and use
    crystallise.llm.client for LLM calls. No DB or UI dependencies.

    Every LLM round-trip goes through ``_complete`` (shared completion-token
    budget). JSON replies decoded directly by this class go through
    ``_load_json_payload``, which tolerates markdown code fences and prose
    around the JSON.
    """

    # Completion-token budget sent with every LLM call made by this service.
    _MAX_COMPLETION_TOKENS = 4096
    # Duplicate groups / consolidation proposals below this ai_confidence are rejected.
    _MIN_CONSOLIDATION_CONFIDENCE = 0.75
    # Proposed merged criterion labels longer than this many words are rejected.
    _MAX_MERGED_LABEL_WORDS = 10
    # System message used by refine_criteria.
    _REFINEMENT_SYSTEM_MESSAGE = (
        "You are a systematic review methodology expert. Analyze screening conflicts "
        "and suggest criteria refinements. Return JSON."
    )

    def __init__(self, project_id: int = 0):
        # Stamped onto every ExclusionCriterion parsed by this service.
        self.project_id = project_id

    # ================================================================
    # Shared helpers
    # ================================================================

    async def _complete(
        self,
        system_prompt: str,
        user_prompt: str,
        *,
        model: str,
        api_key: Optional[str],
        **options,
    ):
        """Run one chat completion and return the raw reply content.

        Args:
            system_prompt: System message.
            user_prompt: User message.
            model: OpenAI model name.
            api_key: OpenAI API key (falls back to env var).
            **options: Extra keyword arguments (e.g. ``temperature``) forwarded
                unchanged to ``async_chat_completion``.

        Returns:
            Whatever ``async_chat_completion`` returns; callers treat a falsy
            value as an empty reply.
        """
        # Function-scope import, matching the rest of this module.
        from crystallise.llm.client import async_chat_completion

        return await async_chat_completion(
            system_message=system_prompt,
            prompt=user_prompt,
            model=model,
            api_key=api_key,
            max_completion_tokens=self._MAX_COMPLETION_TOKENS,
            **options,
        )

    @staticmethod
    def _load_json_payload(content: str):
        """Decode JSON from an LLM reply.

        Markdown code-fence lines are removed first. If the remaining text is
        not valid JSON, the outermost ``{...}``/``[...]`` span is tried.

        Raises:
            json.JSONDecodeError: If no JSON value can be decoded.
        """
        text = content.strip()
        if text.startswith("```"):
            text = "\n".join(line for line in text.split("\n") if not line.strip().startswith("```"))
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
            if match is None:
                raise
            return json.loads(match.group(1))

    # ================================================================
    # Generate exclusion criteria
    # ================================================================

    async def generate_exclusion_criteria(
        self,
        project_description: Optional[str] = None,
        research_questions: Optional[list[str]] = None,
        additional_notes: Optional[str] = None,
        existing_criteria: Optional[list[ExclusionCriterion]] = None,
        model: str = DEFAULT_CRITERIA_MODEL,
        api_key: Optional[str] = None,
    ) -> list[ExclusionCriterion]:
        """
        Generate exclusion criteria based on project context.

        Args:
            project_description: Project description text
            research_questions: List of research questions
            additional_notes: Additional context for generation
            existing_criteria: Existing criteria (for deduplication)
            model: OpenAI model name
            api_key: OpenAI API key (falls back to env var)

        Returns:
            List of generated ExclusionCriterion objects

        Raises:
            RuntimeError: If the LLM returns empty content.
        """
        system_prompt = exclusion_generation_system_prompt()
        user_prompt = exclusion_generation_user_prompt(project_description, research_questions, additional_notes)

        logger.info(
            "Generating exclusion criteria (model=%s, prompt_len=%d)",
            model,
            len(user_prompt),
        )

        content = await self._complete(system_prompt, user_prompt, model=model, api_key=api_key)
        if not content:
            raise RuntimeError("LLM returned empty content for exclusion criteria generation")

        return _parse_criteria_response(content, self.project_id, existing_criteria)

    # ================================================================
    # Generate inclusion criteria
    # ================================================================

    async def generate_inclusion_criteria(
        self,
        project_description: Optional[str] = None,
        research_questions: Optional[list[str]] = None,
        additional_notes: Optional[str] = None,
        existing_criteria: Optional[list[ExclusionCriterion]] = None,
        model: str = DEFAULT_CRITERIA_MODEL,
        api_key: Optional[str] = None,
    ) -> list[ExclusionCriterion]:
        """
        Generate inclusion criteria based on project context.

        Arguments mirror ``generate_exclusion_criteria``.

        Returns:
            List of generated ExclusionCriterion objects with criterion_type='include'
            (forced, whatever type the LLM reported).

        Raises:
            RuntimeError: If the LLM returns empty content.
        """
        system_prompt = inclusion_generation_system_prompt()
        user_prompt = inclusion_generation_user_prompt(project_description, research_questions, additional_notes)

        logger.info(
            "Generating inclusion criteria (model=%s, prompt_len=%d)",
            model,
            len(user_prompt),
        )

        content = await self._complete(system_prompt, user_prompt, model=model, api_key=api_key)
        if not content:
            raise RuntimeError("LLM returned empty content for inclusion criteria generation")

        criteria = _parse_criteria_response(content, self.project_id, existing_criteria)
        for criterion in criteria:
            criterion.criterion_type = CriterionType.include
        return criteria

    # ================================================================
    # PICO Extraction
    # ================================================================

    async def extract_pico(
        self,
        project_description: Optional[str] = None,
        research_questions: Optional[list[str]] = None,
        existing_criteria: Optional[list[dict]] = None,
        model: str = DEFAULT_CRITERIA_MODEL,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Extract PICO elements from project context.

        Returns:
            Dict with pico_extraction, gap_flags, and contraindications

        Raises:
            RuntimeError: If the LLM returns empty content.
            json.JSONDecodeError: If the reply contains no decodable JSON
                (code fences and surrounding prose are tolerated).
        """
        system_prompt = pico_extraction_system_prompt()
        user_prompt = pico_extraction_user_prompt(project_description, research_questions, existing_criteria)

        logger.info(
            "Extracting PICO elements (model=%s, prompt_len=%d)",
            model,
            len(user_prompt),
        )

        content = await self._complete(system_prompt, user_prompt, model=model, api_key=api_key, temperature=0.1)
        if not content:
            raise RuntimeError("LLM returned empty content for PICO extraction")

        return self._load_json_payload(content)

    # ================================================================
    # Refine from reconciliation patterns
    # ================================================================

    async def refine_criteria(
        self,
        reconciliation_patterns: dict,
        current_criteria: list[dict],
        inclusion_criteria: list[dict],
        project_description: str = "",
        model: str = DEFAULT_CRITERIA_MODEL,
        api_key: Optional[str] = None,
    ) -> list[dict]:
        """
        Generate criteria refinement suggestions from reconciliation patterns.

        Returns:
            List of suggestion dicts with keys: action, category, text, rationale.
            Empty when the LLM returns empty content.
        """
        user_prompt = refinement_user_prompt(
            reconciliation_patterns,
            current_criteria,
            inclusion_criteria,
            project_description,
        )

        logger.info("Generating refinement suggestions (model=%s)", model)

        content = await self._complete(self._REFINEMENT_SYSTEM_MESSAGE, user_prompt, model=model, api_key=api_key)
        if not content:
            return []

        return _parse_refinement_response(content)

    # ================================================================
    # Refine from conflict papers
    # ================================================================

    async def reconcile_conflicts(
        self,
        conflict_papers: list[dict],
        active_criteria: Optional[list[ExclusionCriterion]] = None,
        project_description: str = "",
        research_questions: Optional[list[str]] = None,
        model: str = DEFAULT_CRITERIA_MODEL,
        api_key: Optional[str] = None,
    ) -> list[ExclusionCriterion]:
        """
        Analyze human/AI decision conflicts and generate refined eligibility criteria.

        Args:
            conflict_papers: Papers where human and AI decisions disagree
            active_criteria: Current active criteria for context (also used for deduplication)
            project_description: Project description
            research_questions: List of research questions
            model: OpenAI model name
            api_key: OpenAI API key

        Returns:
            List of ExclusionCriterion objects to improve screening accuracy.
            Empty when there are no conflict papers or the LLM returns empty content.
        """
        if not conflict_papers:
            return []

        system_prompt = conflict_refinement_system_prompt()
        user_prompt = conflict_refinement_user_prompt(
            conflict_papers, active_criteria, project_description, research_questions
        )

        logger.info(
            "Reconciling conflicts (model=%s, papers=%d)",
            model,
            len(conflict_papers),
        )

        content = await self._complete(system_prompt, user_prompt, model=model, api_key=api_key)
        if not content:
            return []

        return _parse_criteria_response(content, self.project_id, active_criteria)

    # ================================================================
    # Consolidate & Deduplicate
    # ================================================================

    async def consolidate_criteria(
        self,
        criteria: list[ExclusionCriterion],
        project_description: Optional[str] = None,
        research_questions: Optional[list[str]] = None,
        model: str = DEFAULT_CRITERIA_MODEL,
        api_key: Optional[str] = None,
    ) -> ConsolidationResult:
        """
        Detect duplicates and propose consolidations for eligibility criteria.

        Args:
            criteria: List of active criteria to analyze
            project_description: Project context
            research_questions: List of research questions
            model: OpenAI model name
            api_key: OpenAI API key

        Returns:
            ConsolidationResult with duplicate_groups, proposals, and warnings.
            With fewer than 2 criteria, or an empty LLM reply, the result only
            carries a warning.
        """
        if len(criteria or []) < 2:
            return ConsolidationResult(warnings=["Need at least 2 criteria to detect duplicates."])

        system_prompt = consolidation_system_prompt()
        user_prompt = consolidation_user_prompt(criteria, project_description, research_questions)

        logger.info(
            "Detecting duplicates/consolidation (model=%s, criteria=%d)",
            model,
            len(criteria),
        )

        content = await self._complete(system_prompt, user_prompt, model=model, api_key=api_key, temperature=0.0)
        if not content:
            return ConsolidationResult(warnings=["Empty response from LLM"])

        return self._parse_consolidation_response(content, criteria)

    def _parse_consolidation_response(
        self,
        content: str,
        criteria: list[ExclusionCriterion],
    ) -> ConsolidationResult:
        """Parse LLM response into ConsolidationResult with safety guards.

        Guards:
        * undecodable JSON, or a payload that is not a JSON object, yields a
          result carrying only a warning;
        * malformed duplicate-group / proposal entries are skipped with a
          warning instead of aborting the whole parse;
        * duplicate groups and proposals whose ``ai_confidence`` is below
          ``_MIN_CONSOLIDATION_CONFIDENCE`` are rejected;
        * proposals whose merged label has more than
          ``_MAX_MERGED_LABEL_WORDS`` words are rejected.

        Kept entries get ``criteria_texts`` filled from *criteria*; ids not
        present in *criteria* are ignored. Warnings supplied by the LLM come
        first, followed by skip/rejection notes in processing order.
        """
        text_by_id = {c.id: c.text for c in criteria}

        try:
            data = self._load_json_payload(content)
        except json.JSONDecodeError as e:
            return ConsolidationResult(warnings=[f"Failed to parse LLM response: {e}"])
        if not isinstance(data, dict):
            return ConsolidationResult(
                warnings=[f"Unexpected LLM response: expected a JSON object, got {type(data).__name__}"]
            )

        # A bare string must not be exploded into characters by list(...).
        raw_warnings = data.get("warnings")
        if isinstance(raw_warnings, str):
            raw_warnings = [raw_warnings]
        elif not isinstance(raw_warnings, list):
            raw_warnings = []
        warnings = [str(w) for w in raw_warnings]

        raw_groups = data.get("duplicate_groups")
        raw_proposals = data.get("consolidation_proposals")
        min_confidence = self._MIN_CONSOLIDATION_CONFIDENCE

        duplicate_groups: list[DuplicateGroup] = []
        for raw in raw_groups if isinstance(raw_groups, list) else []:
            try:
                group = DuplicateGroup(**raw)
            except (TypeError, ValueError) as exc:
                warnings.append(f"Skipped malformed duplicate group: {exc}")
                continue
            if group.ai_confidence < min_confidence:
                warnings.append(
                    f"Rejected low-confidence duplicate group ({group.confidence_percentage}%): IDs {group.criterion_ids}"
                )
                continue
            group.criteria_texts = [text_by_id[cid] for cid in group.criterion_ids if cid in text_by_id]
            duplicate_groups.append(group)

        consolidation_proposals: list[ConsolidationProposal] = []
        for raw in raw_proposals if isinstance(raw_proposals, list) else []:
            try:
                proposal = ConsolidationProposal(**raw)
            except (TypeError, ValueError) as exc:
                warnings.append(f"Skipped malformed consolidation proposal: {exc}")
                continue
            word_count = len(proposal.proposed_merged_criterion.split())
            if word_count > self._MAX_MERGED_LABEL_WORDS:
                warnings.append(
                    f"Rejected proposal - label too long ({word_count} words): '{proposal.proposed_merged_criterion[:50]}...'"
                )
            elif proposal.ai_confidence < min_confidence:
                warnings.append(
                    f"Rejected low-confidence proposal ({proposal.confidence_percentage}%): IDs {proposal.criterion_ids}"
                )
            else:
                proposal.criteria_texts = [text_by_id[cid] for cid in proposal.criterion_ids if cid in text_by_id]
                consolidation_proposals.append(proposal)

        return ConsolidationResult(
            duplicate_groups=duplicate_groups,
            consolidation_proposals=consolidation_proposals,
            warnings=warnings,
        )


# ---------------------------------------------------------------------------
# Standalone wrapper functions for router imports
# ---------------------------------------------------------------------------


def _criterion_to_dict(c: ExclusionCriterion) -> dict:
    """Flatten a criterion into the plain-dict shape consumed by the router.

    ``criterion_type`` is emitted as its ``.value`` when it has one (e.g. an
    enum member), otherwise as ``str(...)``; the ``ai_`` prefixes are dropped
    from the confidence and rationale keys.
    """
    raw_type = c.criterion_type
    type_label = raw_type.value if hasattr(raw_type, "value") else str(raw_type)
    return dict(
        category=c.category,
        text=c.text,
        description=c.description,
        criterion_type=type_label,
        confidence=c.ai_confidence,
        rationale=c.ai_rationale,
        title_abstract_assessable=c.title_abstract_assessable,
    )


async def generate_criteria(
    project_description: str = "",
    research_questions: list[str] | None = None,
    additional_notes: str = "",
    existing_criteria: list[dict] | None = None,
    criterion_type: str = "exclude",
    model: str = DEFAULT_CRITERIA_MODEL,
    api_key: Optional[str] = None,
) -> list[dict]:
    """Standalone wrapper: generate criteria via CriteriaAIService.

    Args:
        project_description: Project description text.
        research_questions: Research questions to ground the criteria.
        additional_notes: Extra free-text guidance for generation.
        existing_criteria: Criteria already on the project, as router dicts
            (their ``"text"`` value is used) or plain strings. Generated
            criteria whose text matches one of these case-insensitively are
            dropped.
        criterion_type: ``"include"`` generates inclusion criteria; any other
            value generates exclusion criteria.
        model: OpenAI model name.
        api_key: OpenAI API key (falls back to env var).

    Returns:
        Generated criteria as plain dicts (see ``_criterion_to_dict``).
    """
    from types import SimpleNamespace

    # The service de-duplicates against objects exposing ``.text``. Wrap the
    # router's texts so ``existing_criteria`` is honoured; it used to be
    # accepted and silently ignored.
    known = []
    for entry in existing_criteria or []:
        text = entry.get("text") if isinstance(entry, dict) else entry
        if isinstance(text, str) and text.strip():
            known.append(SimpleNamespace(text=text.strip()))

    svc = CriteriaAIService()
    generate = svc.generate_inclusion_criteria if criterion_type == "include" else svc.generate_exclusion_criteria
    results = await generate(
        project_description=project_description,
        research_questions=research_questions,
        additional_notes=additional_notes,
        existing_criteria=known or None,  # type: ignore[arg-type]  # duck-typed: only ``.text`` is read
        model=model,
        api_key=api_key,
    )
    return [_criterion_to_dict(c) for c in results]


async def extract_pico(
    project_description: str = "",
    research_questions: list[str] | None = None,
    model: str = DEFAULT_CRITERIA_MODEL,
    api_key: Optional[str] = None,
    existing_criteria: list[dict] | None = None,
) -> dict:
    """Standalone wrapper: extract PICO via CriteriaAIService.

    Args:
        project_description: Project description text.
        research_questions: Research questions for the project.
        model: OpenAI model name.
        api_key: OpenAI API key (falls back to env var).
        existing_criteria: Optional existing criteria (plain dicts) forwarded
            to the PICO prompt. Added last, with a ``None`` default, so
            existing callers are unaffected.

    Returns:
        Dict with pico_extraction, gap_flags, and contraindications.
    """
    svc = CriteriaAIService()
    return await svc.extract_pico(
        project_description=project_description,
        research_questions=research_questions,
        existing_criteria=existing_criteria,
        model=model,
        api_key=api_key,
    )


async def refine_criteria(
    current_criteria: list[dict] | None = None,
    conflicts: list[dict] | None = None,
    project_description: str = "",
    model: str = DEFAULT_CRITERIA_MODEL,
    api_key: Optional[str] = None,
    inclusion_criteria: list[dict] | None = None,
) -> list[dict]:
    """Standalone wrapper: refine criteria via CriteriaAIService.

    Args:
        current_criteria: Current criteria as plain dicts.
        conflicts: Conflict records, passed to the service as
            ``{"conflicts": conflicts}`` reconciliation patterns.
        project_description: Project description text.
        model: OpenAI model name.
        api_key: OpenAI API key (falls back to env var).
        inclusion_criteria: Optional inclusion criteria (plain dicts) for
            prompt context. Previously always sent as ``[]``; added last with
            a ``None`` default so existing callers are unaffected.

    Returns:
        List of suggestion dicts (action, category, text, rationale).
    """
    svc = CriteriaAIService()
    return await svc.refine_criteria(
        reconciliation_patterns={"conflicts": conflicts or []},
        current_criteria=current_criteria or [],
        inclusion_criteria=inclusion_criteria or [],
        project_description=project_description,
        model=model,
        api_key=api_key,
    )


async def refine_context(
    description: str = "",
    research_questions: list[str] | None = None,
    model: str = DEFAULT_CRITERIA_MODEL,
    api_key: Optional[str] = None,
) -> dict:
    """Refine project description and research questions for better screening.

    Args:
        description: Current project description.
        research_questions: Current research questions (``None`` means none).
        model: OpenAI model name.
        api_key: OpenAI API key (falls back to env var).

    Returns:
        The JSON value decoded from the LLM reply. Markdown code fences and text
        around the outermost ``{...}`` object are tolerated.

    Raises:
        RuntimeError: If the LLM returns empty content.
        json.JSONDecodeError: If no JSON can be decoded from the reply.
    """
    from crystallise.llm.client import async_chat_completion

    system_prompt = context_refinement_system_prompt()
    user_prompt = context_refinement_user_prompt(
        description,
        research_questions or [],
    )

    logger.info("Refining project context (model=%s)", model)

    content = await async_chat_completion(
        system_message=system_prompt,
        prompt=user_prompt,
        model=model,
        api_key=api_key,
        max_completion_tokens=4096,
    )

    if not content:
        raise RuntimeError("LLM returned empty content for context refinement")

    # Tolerate the same code fences / surrounding prose that
    # _parse_criteria_response handles, instead of failing on raw json.loads.
    text = content.strip()
    if text.startswith("```"):
        text = "\n".join(line for line in text.split("\n") if not line.strip().startswith("```"))
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        match = re.search(r"\{[\s\S]*\}", text)
        if match is None:
            raise
        return json.loads(match.group(0))


async def consolidate_criteria(
    criteria: list[ExclusionCriterion],
    project_description: Optional[str] = None,
    research_questions: Optional[list[str]] = None,
    model: str = DEFAULT_CRITERIA_MODEL,
    api_key: Optional[str] = None,
) -> ConsolidationResult:
    """Detect duplicate / mergeable criteria without managing a service instance.

    Builds a default ``CriteriaAIService`` and delegates to its
    ``consolidate_criteria`` method, forwarding every argument unchanged.
    """
    return await CriteriaAIService().consolidate_criteria(
        criteria=criteria,
        project_description=project_description,
        research_questions=research_questions,
        model=model,
        api_key=api_key,
    )


def _decode_json_reply(content: str):
    """Decode the JSON value in an LLM reply.

    Markdown code-fence lines are removed first. If the remaining text is not
    valid JSON, the outermost ``{...}`` span is tried.

    Raises:
        json.JSONDecodeError: If no JSON can be decoded.
    """
    text = content.strip()
    if text.startswith("```"):
        text = "\n".join(line for line in text.split("\n") if not line.strip().startswith("```"))
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        match = re.search(r"\{[\s\S]*\}", text)
        if match is None:
            raise
        return json.loads(match.group(0))


async def analyze_research_question(
    research_question: str,
    model: str = "gpt-5-mini",
    api_key: Optional[str] = None,
) -> dict:
    """Analyse a single research question for search-readiness (PICOS framing).

    Returns a dict with keys: status ("ready" | "could_improve"), missing_elements
    (list of strings), suggestion (string). The router validates against the
    Pydantic response schema.

    Raises:
        RuntimeError: If the LLM returns empty content.
        json.JSONDecodeError: If no JSON can be decoded from the reply (code
            fences and surrounding prose are tolerated).
    """
    from crystallise.llm.client import async_chat_completion

    system_prompt = question_analysis_system_prompt()
    user_prompt = question_analysis_user_prompt(research_question)

    logger.info("Analysing research question (model=%s, len=%d)", model, len(research_question))

    content = await async_chat_completion(
        system_message=system_prompt,
        prompt=user_prompt,
        model=model,
        api_key=api_key,
        max_completion_tokens=4096,
    )

    if not content:
        raise RuntimeError("LLM returned empty content for question analysis")

    return _decode_json_reply(content)
