"""
Mock AI Service for development and testing.

Simulates AI screening behavior with deterministic outputs
for testing and development without API costs.
"""

from __future__ import annotations

import asyncio
import logging
import random
import zlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import pandas as pd

# Module-level logger, named after this module's import path; all "[MOCK]" messages go here.
logger: logging.Logger = logging.getLogger(name=__name__)


class MockAIService:
    """
    Simulates AI screening behavior for development.

    Provides deterministic scoring based on content features to
    enable testing without real LLM calls. Every pseudo-random choice is
    seeded from :meth:`_stable_hash` (CRC-32) instead of the built-in
    ``hash()``: ``hash()`` of a ``str`` is salted per interpreter process
    (see ``PYTHONHASHSEED``), which made "deterministic" results differ
    between runs.
    """

    # Simulated processing delays (seconds). DELAY_PER_PAPER is the default
    # per-item sleep; screen_papers caps it so that the three per-paper stages
    # together sleep for at most about TOTAL_TARGET_TIME.
    DELAY_PER_PAPER = 0.0
    TOTAL_TARGET_TIME = 0.5

    # Keywords that influence scoring (matched as substrings of lowercased text)
    INCLUDE_KEYWORDS = [
        "randomized",
        "clinical trial",
        "efficacy",
        "outcome",
        "treatment",
        "patient",
        "therapy",
        "results",
        "study",
        "phase iii",
        "phase ii",
        "controlled",
        "intervention",
    ]

    EXCLUDE_KEYWORDS = [
        "review",
        "meta-analysis",
        "protocol",
        "editorial",
        "commentary",
        "letter",
        "opinion",
        "case report",
        "guideline",
        "recommendation",
    ]

    # Cluster templates - matches real AI output format from clustering.py.
    # Treat as read-only: generate_mock_clusters hands out copies.
    CLUSTER_TEMPLATES = [
        {
            "cluster_name": "Treatment Efficacy Studies",
            "cluster_type": "include",
            "cluster_description": "Papers evaluating treatment outcomes and efficacy measures",
            "related_criteria": "Intervention",
        },
        {
            "cluster_name": "Methodology & Protocol Only",
            "cluster_type": "exclude",
            "cluster_description": "Papers focused on study design and methodology without clinical results",
            "related_criteria": "Study Design",
        },
        {
            "cluster_name": "Adult Patient Population",
            "cluster_type": "include",
            "cluster_description": "Papers focused on adult patient demographics with defined inclusion criteria",
            "related_criteria": "Population",
        },
        {
            "cluster_name": "Biomarker Analysis Studies",
            "cluster_type": "include",
            "cluster_description": "Papers with molecular or biomarker focus and genetic data",
            "related_criteria": "Outcome",
        },
        {
            "cluster_name": "Review & Meta-Analysis (No Primary Data)",
            "cluster_type": "exclude",
            "cluster_description": "Secondary research synthesizing existing studies without new data",
            "related_criteria": "Study Design",
        },
        {
            "cluster_name": "Safety & Adverse Events Reporting",
            "cluster_type": "include",
            "cluster_description": "Papers focusing on treatment safety profiles and toxicity assessment",
            "related_criteria": "Outcome",
        },
        {
            "cluster_name": "Quality of Life Outcomes",
            "cluster_type": "include",
            "cluster_description": "Papers measuring patient-reported outcomes and functional status",
            "related_criteria": "Outcome",
        },
        {
            "cluster_name": "Pure Economic Analysis (No Clinical Data)",
            "cluster_type": "exclude",
            "cluster_description": "Cost-effectiveness and health economics studies without clinical outcomes",
            "related_criteria": "Other",
        },
    ]

    def __init__(self, delay_per_paper: float | None = None):
        """
        Args:
            delay_per_paper: Per-item simulated delay in seconds; defaults to
                ``DELAY_PER_PAPER`` (0.0, i.e. no sleeping).
        """
        self.delay_per_paper = delay_per_paper if delay_per_paper is not None else self.DELAY_PER_PAPER
        self._cancelled = False

    @staticmethod
    def _stable_hash(text: str) -> int:
        """Return a non-negative hash of *text* that is identical across processes.

        Uses CRC-32 of the UTF-8 encoding. The built-in ``hash()`` is randomized per
        process for ``str`` and therefore unsuitable for reproducible mock output.
        """
        return zlib.crc32(str(text).encode("utf-8", errors="surrogatepass"))

    @staticmethod
    def calculate_mock_score(abstract: str, title: str = "") -> float:
        """
        Generate a deterministic score based on content features.

        A stable hash of the first 50 characters of the abstract plus the first
        30 characters of the title selects one of 100 buckets. The buckets are
        mapped onto five score bands that put roughly 40% of papers below 2.5.
        Keyword hits then nudge the score: at most +0.3 for include keywords
        and at most -0.3 for exclude keywords.

        Args:
            abstract: Paper abstract; an empty/falsy abstract scores 1.5.
            title: Paper title (optional).

        Returns:
            Score between 1.0 and 5.0
        """
        if not abstract:
            return 1.5
        title = title or ""

        text = f"{title} {abstract}".lower()

        # Hash short prefixes of abstract and title, then map onto 100 buckets.
        combined_hash = MockAIService._stable_hash(abstract[:50])
        if title:
            combined_hash += MockAIService._stable_hash(title[:30])
        bucket = combined_hash % 100

        if bucket < 20:
            base_score = 1.0 + (bucket / 20) * 0.8
        elif bucket < 40:
            base_score = 1.8 + ((bucket - 20) / 20) * 0.6
        elif bucket < 60:
            base_score = 2.4 + ((bucket - 40) / 20) * 0.4
        elif bucket < 80:
            base_score = 2.8 + ((bucket - 60) / 20) * 0.7
        else:
            base_score = 3.5 + ((bucket - 80) / 20) * 1.0

        # Small adjustments based on keywords (max +/- 0.3)
        include_count = sum(1 for kw in MockAIService.INCLUDE_KEYWORDS if kw in text)
        exclude_count = sum(1 for kw in MockAIService.EXCLUDE_KEYWORDS if kw in text)

        base_score += min(include_count * 0.05, 0.3)
        base_score -= min(exclude_count * 0.1, 0.3)

        return max(1.0, min(5.0, base_score))

    @staticmethod
    def generate_mock_evidence(
        title: str,
        abstract: str,
        score: float,
    ) -> list[dict[str, str]]:
        """Generate mock evidence spans by extracting real substrings from title/abstract.

        For abstracts of at least 6 words, ``min(4, max(2, int(score)))`` spans of
        4-8 consecutive words are picked with an RNG seeded from the abstract
        prefix and the score. The ``criterion`` labels cycle through a fixed pool.
        Titles of at least 4 words add one span made of their first (up to 6)
        words. Each span ``supports`` "include" when ``score >= 3.0``, otherwise
        "exclude".
        """
        evidence = []
        criteria_pool = ["Population", "Study Design", "Outcome", "Intervention"]
        supports = "include" if score >= 3.0 else "exclude"

        words = abstract.split() if abstract else []
        if len(words) >= 6:
            # Local RNG seeded from a process-stable hash: reproducible and thread-safe.
            rng = random.Random(MockAIService._stable_hash(abstract[:30]) + int(score * 10))
            num_spans = min(4, max(2, int(score)))
            for i in range(num_spans):
                start = rng.randint(0, max(0, len(words) - 6))
                span_len = rng.randint(4, min(8, len(words) - start))
                evidence.append(
                    {
                        "text": " ".join(words[start : start + span_len]),
                        "section": "abstract",
                        "criterion": criteria_pool[i % len(criteria_pool)],
                        "supports": supports,
                    }
                )

        title_words = title.split() if title else []
        if len(title_words) >= 4:
            evidence.append(
                {
                    "text": " ".join(title_words[:6]),
                    "section": "title",
                    "criterion": criteria_pool[0],
                    "supports": supports,
                }
            )

        return evidence

    @staticmethod
    def generate_mock_reasoning(title: str, abstract: str, score: float) -> str:
        """Generate mock reasoning text for a paper based on its score band.

        Bands: >=4.0 strong, >=3.0 moderate, >=2.0 weak, otherwise recommend
        exclusion. Matched keywords are appended to the message in some bands.
        ``None`` title/abstract are treated as empty strings.
        """
        title = title or ""
        abstract = abstract or ""
        title_preview = title[:50] + "..." if len(title) > 50 else title

        text = f"{title} {abstract}".lower()
        found_include = [kw for kw in MockAIService.INCLUDE_KEYWORDS if kw in text]
        found_exclude = [kw for kw in MockAIService.EXCLUDE_KEYWORDS if kw in text]

        if score >= 4.0:
            keywords_note = f" Keywords suggesting inclusion: {', '.join(found_include[:3])}." if found_include else ""
            return f"Strong candidate for inclusion. '{title_preview}' directly addresses research criteria with clear methodology.{keywords_note}"
        elif score >= 3.0:
            return f"Moderate relevance. '{title_preview}' partially meets inclusion criteria but requires human review for final determination."
        elif score >= 2.0:
            keywords_note = f" Potential exclusion indicators: {', '.join(found_exclude[:2])}." if found_exclude else ""
            return (
                f"Weak relevance. '{title_preview}' appears tangentially related to the research topic.{keywords_note}"
            )
        else:
            keywords_note = f" Exclusion keywords found: {', '.join(found_exclude[:2])}." if found_exclude else ""
            return f"Recommend exclusion. '{title_preview}' does not appear to meet core inclusion criteria.{keywords_note}"

    @staticmethod
    def score_to_decision(score: float, threshold: float = 1.0) -> str:
        """Convert a score to "include" (strictly above *threshold*) or "exclude".

        NOTE: the default threshold (1.0) differs from ``screen_papers``' default
        (2.5). It is kept for backward compatibility, so pass *threshold*
        explicitly when you want realistic splits.
        """
        return "include" if score > threshold else "exclude"

    def generate_mock_clusters(self, num_papers: int, clusters_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Generate mock thematic clusters based on number of papers.

        Args:
            num_papers: Total number of papers (<50 -> 3, <200 -> 5, <500 -> 7,
                otherwise up to 8 clusters, limited by the available templates)
            clusters_type: Filter clusters - 'include', 'exclude', or None for both
                (any other value also means both)

        Returns:
            Fresh copies of the selected templates, so callers may mutate them
            without affecting ``CLUSTER_TEMPLATES``.
        """
        if clusters_type in ("include", "exclude"):
            templates = [c for c in self.CLUSTER_TEMPLATES if c["cluster_type"] == clusters_type]
        else:
            templates = self.CLUSTER_TEMPLATES

        if num_papers < 50:
            num_clusters = 3
        elif num_papers < 200:
            num_clusters = 5
        elif num_papers < 500:
            num_clusters = 7
        else:
            num_clusters = 8

        return [dict(c) for c in templates[:num_clusters]]

    def assign_cluster(self, score: float, abstract: str, clusters: List[Dict]) -> List[Dict]:
        """
        Assign paper to clusters based on content and score.

        - score >= 4.0: the first cluster, plus at most one topical cluster
          (safety, biomarker or quality, in that priority).
        - score < 2.0: review/methodology/economic clusters by keyword, falling
          back to the second cluster (or the only one, if there is just one).
        - otherwise: patient/biomarker/safety/quality clusters by keyword, falling
          back to a hash-selected cluster (and sometimes its neighbour).

        Returns:
            List of new dicts with cluster_name, cluster_type, cluster_description
            and related_criteria. Returns an empty list when *clusters* is empty.
        """
        text = abstract.lower() if abstract else ""
        assigned: List[Dict[str, Any]] = []
        if not clusters:
            return assigned

        def add_cluster(cluster_obj: Dict) -> None:
            """Add a copy of *cluster_obj* unless a cluster with that name is already assigned."""
            if any(c["cluster_name"] == cluster_obj["cluster_name"] for c in assigned):
                return
            assigned.append(
                {
                    "cluster_name": cluster_obj["cluster_name"],
                    "cluster_type": cluster_obj.get("cluster_type", "exclude"),
                    "cluster_description": cluster_obj.get("cluster_description", ""),
                    "related_criteria": cluster_obj.get("related_criteria", []),
                }
            )

        def add_named(fragment: str) -> None:
            """Add the first cluster whose lowercased name contains *fragment*, if any."""
            for c in clusters:
                if fragment in c["cluster_name"].lower():
                    add_cluster(c)
                    return

        def mentions(*keywords: str) -> bool:
            """True when any keyword occurs in the lowercased abstract."""
            return any(kw in text for kw in keywords)

        # High-scoring papers go to the first cluster, possibly with one topical extra
        if score >= 4.0:
            add_cluster(clusters[0])
            if mentions("safety", "adverse"):
                add_named("safety")
            elif mentions("biomarker", "molecular"):
                add_named("biomarker")
            elif mentions("quality", "qol"):
                add_named("quality")
            return assigned

        # Low-scoring papers
        if score < 2.0:
            if mentions("review", "meta-analysis", "systematic"):
                add_named("review")
            if mentions("protocol", "design", "methodology"):
                add_named("methodology")
            if mentions("cost", "economic", "budget"):
                add_named("economic")
            if not assigned:
                # Second cluster by convention; guard against single-cluster lists.
                add_cluster(clusters[min(1, len(clusters) - 1)])
            return assigned

        # Medium scores - assign based on content
        if mentions("patient", "population", "cohort"):
            add_named("patient")
        if mentions("biomarker", "molecular", "genetic"):
            add_named("biomarker")
        if mentions("safety", "adverse", "toxicity"):
            add_named("safety")
        if mentions("quality", "qol", "patient-reported"):
            add_named("quality")

        # If nothing matched, assign based on a stable hash for reproducible variety
        if not assigned:
            text_hash = self._stable_hash(text[:100])
            cluster_idx = text_hash % len(clusters)
            add_cluster(clusters[cluster_idx])
            if text_hash % 3 == 0 and len(clusters) > 2:
                add_cluster(clusters[(cluster_idx + 1) % len(clusters)])

        return assigned

    def cancel(self):
        """Request cancellation; screen_papers checks this flag before each paper and between stages."""
        self._cancelled = True

    def reset(self):
        """Clear the cancellation flag (screen_papers calls this when it starts)."""
        self._cancelled = False

    def _score_row(
        self, row: pd.Series, use_existing: bool, repetitions: int
    ) -> Tuple[List[float], float, float, float]:
        """Return ``(score_list, score, score_min, score_max)`` for one paper row.

        When *use_existing* is true and the row has a non-null ``ai_score``
        (preferred) or ``reference_score``, that value is kept. Its min/max come
        from the row's ``ai_score_min``/``ai_score_max`` when present, otherwise
        +/-0.2 clamped to [1, 5]. Otherwise *repetitions* mock scores are
        generated by jittering ``calculate_mock_score`` by +/-0.3 (clamped to
        [1, 5]) and averaged.
        """
        if use_existing:
            for col in ("ai_score", "reference_score"):
                value = row.get(col)
                if pd.notna(value):
                    score = float(value)
                    existing_min = row.get("ai_score_min")
                    existing_max = row.get("ai_score_max")
                    score_min = float(existing_min) if pd.notna(existing_min) else max(1.0, score - 0.2)
                    score_max = float(existing_max) if pd.notna(existing_max) else min(5.0, score + 0.2)
                    return [score], score, score_min, score_max

        abstract = row.get("abstract", "")
        base_score = self.calculate_mock_score(abstract, row.get("title", ""))
        rng = random.Random(self._stable_hash(str(abstract)[:50]))
        score_list = [max(1.0, min(5.0, base_score + rng.uniform(-0.3, 0.3))) for _ in range(repetitions)]
        score = sum(score_list) / len(score_list)
        return score_list, score, min(score_list), max(score_list)

    async def screen_papers(
        self,
        papers_df: pd.DataFrame,
        criteria: Optional[List[Dict]] = None,
        questions: Optional[List[str]] = None,
        progress_callback: Optional[Callable[[int, int, str], None]] = None,
        repetitions: int = 5,
        clusters_type: Optional[str] = None,
        threshold: float = 2.5,
        preserve_existing_scores: bool = True,
    ) -> Tuple[pd.DataFrame, List[Dict[str, Any]]]:
        """
        Simulate the 4-stage AI screening pipeline.

        Stages: (1) scoring, (2) reasoning/evidence/decision, (3) cluster
        generation, (4) cluster assignment. Progress is reported on a
        ``total * 4`` scale. If :meth:`cancel` is called during a run, the
        current stage stops at the next paper and the partially-filled results
        are returned (with ``[]`` clusters if cancellation is seen before stage 3).

        Args:
            papers_df: DataFrame with papers to screen (not modified; a copy is returned)
            criteria: List of criteria (unused in mock, kept for interface compatibility)
            questions: List of research questions (unused in mock)
            progress_callback: Callback(current, total, stage_name)
            repetitions: Number of scoring repetitions (values below 1 are treated as 1)
            clusters_type: Filter clusters - 'include', 'exclude', or None for both
            threshold: Score threshold for include/exclude (default 2.5)
            preserve_existing_scores: If True, preserve existing ai_score values
                (or reference_score values when no ai_score is present)

        Returns:
            Tuple of (results_df with AI columns, clusters list)
        """
        self.reset()
        # Zero repetitions would otherwise divide by zero / call min([]).
        repetitions = max(1, int(repetitions))
        results = papers_df.copy()
        total = len(papers_df)

        # Missing dict keys become NaN (float) when list[dict] → DataFrame; downstream
        # str ops (len/slice/split) crash on NaN. Coerce text cols once.
        for col in ("title", "abstract"):
            if col in results.columns:
                results[col] = results[col].fillna("").astype(str)
            else:
                results[col] = ""

        # Convert ai_decision column to object dtype if it exists
        if "ai_decision" in results.columns:
            results["ai_decision"] = results["ai_decision"].astype(object)

        logger.info("[MOCK] Starting AI screening: %s papers", total)

        if total == 0:
            logger.warning("[MOCK] No papers to screen")
            return results, []

        # Cap the per-item sleep so the three per-paper stages stay near TOTAL_TARGET_TIME.
        stage_delay = min(self.delay_per_paper, self.TOTAL_TARGET_TIME / max(total * 3, 1))

        # Check for existing scores
        has_ai_scores = (
            preserve_existing_scores and "ai_score" in papers_df.columns and papers_df["ai_score"].notna().any()
        )
        has_reference_scores = (
            preserve_existing_scores
            and "reference_score" in papers_df.columns
            and papers_df["reference_score"].notna().any()
        )
        has_existing_scores = has_ai_scores or has_reference_scores

        if has_existing_scores:
            score_col = "ai_score" if has_ai_scores else "reference_score"
            logger.info("[MOCK] Preserving %s %s values", papers_df[score_col].notna().sum(), score_col)

        # Initialize result columns. Score columns are reset unless preserving, and are
        # always created when missing (instead of relying on `.at` enlargement).
        results["score_list"] = None
        for col in ("ai_score_min", "ai_score_max", "ai_score"):
            if not has_existing_scores or col not in results.columns:
                results[col] = None
        results["ai_reasoning"] = None
        results["evidence"] = None
        results["ai_decision"] = None
        results["assigned_clusters"] = None

        progress_interval = max(10, total // 50)
        progress_total = total * 4

        def report(idx: int, offset: int, stage: str) -> None:
            """Report progress every progress_interval-th paper and on the last paper."""
            if progress_callback and (idx % progress_interval == 0 or idx == total - 1):
                progress_callback(offset + idx + 1, progress_total, stage)

        logger.info("[MOCK] Stage 1/4: Scoring %s papers", total)

        # Stage 1: Labelling (scoring)
        for idx, (row_idx, row) in enumerate(results.iterrows()):
            if self._cancelled:
                break

            score_list, score, score_min, score_max = self._score_row(row, has_existing_scores, repetitions)
            results.at[row_idx, "score_list"] = score_list
            results.at[row_idx, "ai_score_min"] = score_min
            results.at[row_idx, "ai_score_max"] = score_max
            results.at[row_idx, "ai_score"] = score

            report(idx, 0, "Scoring papers")
            if stage_delay > 0:
                await asyncio.sleep(stage_delay)

        if self._cancelled:
            logger.warning("[MOCK] Screening cancelled after Stage 1")
            return results, []

        logger.info("[MOCK] Stage 2/4: Generating reasoning")

        # Stage 2: Reasoning
        for idx, (row_idx, row) in enumerate(results.iterrows()):
            if self._cancelled:
                break

            title, abstract, score = row.get("title", ""), row.get("abstract", ""), row["ai_score"]
            results.at[row_idx, "ai_reasoning"] = self.generate_mock_reasoning(title, abstract, score)
            results.at[row_idx, "evidence"] = self.generate_mock_evidence(title, abstract, score)
            results.at[row_idx, "ai_decision"] = self.score_to_decision(score, threshold=threshold)

            report(idx, total, "Generating explanations")
            if stage_delay > 0:
                await asyncio.sleep(stage_delay)

        if self._cancelled:
            logger.warning("[MOCK] Screening cancelled after Stage 2")
            return results, []

        logger.info("[MOCK] Stage 3/4: Generating clusters (filter: %s)", clusters_type or "both")

        # Stage 3: Clustering
        clusters = self.generate_mock_clusters(total, clusters_type=clusters_type)
        if progress_callback:
            progress_callback(total * 2 + 1, progress_total, "Creating thematic clusters")

        if self._cancelled:
            logger.warning("[MOCK] Screening cancelled after Stage 3")
            return results, clusters

        logger.info("[MOCK] Stage 4/4: Assigning clusters (%s clusters)", len(clusters))

        # Stage 4: Cluster assignment
        for idx, (row_idx, row) in enumerate(results.iterrows()):
            if self._cancelled:
                break

            results.at[row_idx, "assigned_clusters"] = self.assign_cluster(
                row["ai_score"], row.get("abstract", ""), clusters
            )

            report(idx, total * 3, "Assigning papers to clusters")
            if stage_delay > 0:
                await asyncio.sleep(stage_delay)

        include_count = int((results["ai_decision"] == "include").sum())
        exclude_count = int((results["ai_decision"] == "exclude").sum())
        logger.info(
            "[MOCK] Screening complete: %s papers | Include: %s | Exclude: %s | Clusters: %s",
            total,
            include_count,
            exclude_count,
            len(clusters),
        )

        return results, clusters

    def generate_criteria_suggestions(
        self, current_criteria: List[Dict], research_questions: List[str], sample_papers: Optional[pd.DataFrame] = None
    ) -> Dict[str, Any]:
        """Generate mock criteria refinement suggestions.

        Always returns the same two additions. For up to the first two
        *current_criteria* it adds a modification whose ``suggested_value``
        appends " (with outcome data)". ``None`` criteria lists and ``None``
        ``criteria_value`` entries are treated as empty.
        """
        suggestions = {
            "additions": [
                {
                    "type": "include",
                    "name": "Study Type",
                    "value": "Randomized controlled trials and prospective cohort studies",
                    "rationale": "Ensures high-quality evidence for treatment efficacy questions",
                },
                {
                    "type": "exclude",
                    "name": "Publication Type",
                    "value": "Conference abstracts without full publication",
                    "rationale": "Limited data availability for extraction",
                },
            ],
            "modifications": [],
            "removals": [],
            "overall_assessment": "Current criteria appear well-structured. Consider adding study type specifications for clearer screening guidance.",
        }

        for criterion in (current_criteria or [])[:2]:
            suggestions["modifications"].append(
                {
                    "original": criterion,
                    "suggested_value": (criterion.get("criteria_value") or "") + " (with outcome data)",
                    "rationale": "Adding outcome data requirement improves specificity",
                }
            )

        return suggestions

    def generate_exclusion_criteria_from_context(
        self, project_description: str, research_questions: List[str]
    ) -> List[Dict[str, Any]]:
        """
        Generate initial exclusion criteria suggestions based on project context.

        Mock implementation: returns three common exclusion criteria, then
        domain-specific ones whose trigger terms appear (as lowercase
        substrings) in the description or research questions, then a language
        restriction. ``None`` entries in *research_questions* are ignored.
        """
        desc_lower = project_description.lower() if project_description else ""
        questions_text = (
            " ".join(str(q) for q in research_questions if q is not None).lower() if research_questions else ""
        )
        context = f"{desc_lower} {questions_text}"

        # Always suggest common exclusion criteria
        suggestions: List[Dict[str, Any]] = [
            {
                "criteria_name": "Study Design",
                "criteria_type": "exclude",
                "criteria_value": "Review articles, systematic reviews, meta-analyses, editorials, commentaries, letters to editor",
                "rationale": "Secondary research and opinion pieces do not provide primary data for synthesis",
            },
            {
                "criteria_name": "Publication Type",
                "criteria_type": "exclude",
                "criteria_value": "Conference abstracts without full publication, protocols, study registrations",
                "rationale": "Insufficient data for extraction and quality assessment",
            },
            {
                "criteria_name": "Study Design",
                "criteria_type": "exclude",
                "criteria_value": "Case reports and case series with <5 patients",
                "rationale": "Very small studies have limited generalizability",
            },
        ]

        # Domain-specific suggestions, checked in order: (trigger terms, suggestion)
        domain_rules = [
            (
                ("oncology", "cancer", "tumor", "tumour", "malignant"),
                {
                    "criteria_name": "Population",
                    "criteria_type": "exclude",
                    "criteria_value": "Benign tumors or non-malignant conditions",
                    "rationale": "Focus on malignant disease per research questions",
                },
            ),
            (
                ("adult", "elderly", "geriatric"),
                {
                    "criteria_name": "Population",
                    "criteria_type": "exclude",
                    "criteria_value": "Pediatric populations (patients <18 years)",
                    "rationale": "Review focused on adult populations",
                },
            ),
            (
                ("pediatric", "child", "children", "adolescent"),
                {
                    "criteria_name": "Population",
                    "criteria_type": "exclude",
                    "criteria_value": "Adult populations (patients >=18 years)",
                    "rationale": "Review focused on pediatric populations",
                },
            ),
            (
                ("rct", "randomized", "randomised", "clinical trial"),
                {
                    "criteria_name": "Study Design",
                    "criteria_type": "exclude",
                    "criteria_value": "Non-randomized studies, observational studies without control group",
                    "rationale": "Review focused on randomized evidence",
                },
            ),
            (
                ("efficacy", "effectiveness", "outcome"),
                {
                    "criteria_name": "Outcome",
                    "criteria_type": "exclude",
                    "criteria_value": "Studies not reporting clinical outcomes or efficacy measures",
                    "rationale": "Primary focus on treatment effectiveness",
                },
            ),
            (
                ("human", "patient", "clinical"),
                {
                    "criteria_name": "Population",
                    "criteria_type": "exclude",
                    "criteria_value": "Animal studies, in vitro studies, computational/modeling studies",
                    "rationale": "Review focused on human clinical evidence",
                },
            ),
        ]
        suggestions.extend(
            suggestion for terms, suggestion in domain_rules if any(term in context for term in terms)
        )

        # Add language suggestion
        suggestions.append(
            {
                "criteria_name": "Publication",
                "criteria_type": "exclude",
                "criteria_value": "Non-English language publications",
                "rationale": "Language restriction for feasibility of review",
            }
        )

        return suggestions

    def generate_from_conflicts(self, conflicts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Generate mock refined criteria derived from screening conflicts.

        Returns criteria shaped for the /criteria/refine endpoint (category/text/
        description/criterion_type/confidence/rationale keys). Only the number of
        conflicts (0 for None/empty) influences the output.
        """
        n = len(conflicts) if conflicts else 0
        return [
            {
                "category": "Study Design",
                "text": "Exclude retrospective observational studies without a control arm",
                "description": "Conflicts frequently arose on studies lacking concurrent controls.",
                "criterion_type": "exclude",
                "confidence": 0.72,
                "rationale": f"Derived from {n} reviewer conflict(s) on study design.",
            },
            {
                "category": "Outcome Reporting",
                "text": "Exclude studies that do not report the primary outcome quantitatively",
                "description": "Disagreements clustered around studies with narrative-only outcome reporting.",
                "criterion_type": "exclude",
                "confidence": 0.65,
                "rationale": f"Pattern across {n} conflict(s) flagged insufficient outcome data.",
            },
        ]
