"""Batch processing with ThreadPoolExecutor and checkpointing."""
from __future__ import annotations

import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Optional


class BatchRunner:
    """
    Run tasks in batches with checkpointing.

    Items are processed as consecutive slices of ``items``. Each call to
    :meth:`run_next_batch` runs one slice on a fresh thread pool. It appends
    that slice's outcomes to the cumulative ``results``, ``usages`` and
    ``errors`` lists, so a caller can inspect or persist progress between
    batches.

    Args:
        items: list of work items
        fn: function(item) -> (result, usage_or_none, error_or_none).
            A truthy error marks the item as failed. Otherwise ``result`` is
            recorded as a success. A non-None usage is always recorded.
            Exceptions raised by ``fn`` are caught and recorded as errors.
        max_workers: number of worker threads used for each batch.
    """

    def __init__(
        self,
        items: list[Any],
        fn: Callable[[Any], tuple[Any, Any, Optional[str]]],
        max_workers: int = 4,
    ):
        self.items = items
        self.fn = fn
        self.max_workers = max_workers
        self._i = 0  # index of the first item not yet processed
        self.results: list[Any] = []
        self.usages: list[Any] = []
        self.errors: list[str] = []
        self.start_time: Optional[float] = None  # wall clock, set on first batch
        self.end_time: Optional[float] = None  # wall clock, set once items are exhausted

    def has_remaining(self) -> bool:
        """Return True while some items have not been processed yet."""
        return self._i < len(self.items)

    def _run_one(self, item: Any) -> tuple[Any, Any, Optional[str]]:
        """Call ``fn`` on one item, converting an exception into an error string."""
        try:
            return self.fn(item)
        except Exception as e:
            # str() of a message-less exception (e.g. ``ValueError()``) is "".
            # An empty error would be falsy and be misread as a success, so
            # fall back to repr(), which always includes the exception type.
            return None, None, str(e) or repr(e)

    def run_next_batch(
        self,
        batch_size: int,
        progress_cb: Callable[[int, int], None] | None = None,
    ) -> tuple[list[Any], list[Any], list[str]]:
        """
        Execute the next batch. Returns (new_results, new_usages, new_errors).

        progress_cb(done_in_batch, total_in_batch) is called after each item,
        in completion order, from the calling thread. Exceptions raised by the
        callback are ignored.

        The returned lists, and the cumulative lists they are appended to,
        follow the input order of the items, not the completion order.

        Raises:
            ValueError: if batch_size < 1. A non-positive size would never
                advance the cursor, or would move it backwards.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if self.start_time is None:
            self.start_time = time.time()
        start = self._i
        end = min(len(self.items), start + batch_size)
        batch = self.items[start:end]
        total = len(batch)

        # Slot per batch position, so aggregation is deterministic regardless
        # of which thread finishes first.
        outcomes: list[Optional[tuple[Any, Any, Optional[str]]]] = [None] * total

        with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
            future_to_pos = {
                ex.submit(self._run_one, item): pos for pos, item in enumerate(batch)
            }
            for done_count, fut in enumerate(as_completed(future_to_pos), start=1):
                outcomes[future_to_pos[fut]] = fut.result()
                if progress_cb is not None:
                    try:
                        progress_cb(done_count, total)
                    except Exception:
                        # Progress reporting is best-effort: a faulty callback
                        # must not abort or corrupt the batch.
                        pass

        new_results: list[Any] = []
        new_usages: list[Any] = []
        new_errors: list[str] = []
        for outcome in outcomes:
            res, usage, err = outcome  # type: ignore[misc]  # every slot is filled above
            if err:
                new_errors.append(err)
            else:
                new_results.append(res)
            if usage is not None:
                new_usages.append(usage)

        self.results.extend(new_results)
        self.usages.extend(new_usages)
        self.errors.extend(new_errors)
        self._i = end

        if not self.has_remaining():
            self.end_time = time.time()

        return new_results, new_usages, new_errors

    def total_time_seconds(self) -> float:
        """Elapsed wall-clock seconds since the first batch started.

        Returns 0.0 before any batch has run. While items remain, returns the
        time elapsed so far. After completion, returns the time up to the
        batch that exhausted the items.
        """
        if self.start_time is None:
            return 0.0
        if self.end_time is None:
            return time.time() - self.start_time
        return self.end_time - self.start_time
