"""Tests for crystallise.batch.runner.BatchRunner."""

import time

from crystallise.batch.runner import BatchRunner


def _ok_fn(item):
    """Simple fn that returns (result, usage, error)."""
    return {"id": item, "status": "done"}, {"input_tokens": 10, "output_tokens": 5}, None


def _error_fn(item):
    """Fn that always raises."""
    raise RuntimeError(f"Failed on {item}")


class TestBatchRunner:
    """Behavioural tests for ``BatchRunner``'s batching, errors and progress."""

    def test_processes_all_items(self):
        """A batch larger than the item count processes everything at once."""
        runner = BatchRunner(items=[1, 2, 3], fn=_ok_fn, max_workers=2)
        runner.run_next_batch(batch_size=10)
        assert len(runner.results) == 3
        assert len(runner.errors) == 0
        assert not runner.has_remaining()

    def test_handles_errors_gracefully(self):
        """Exceptions raised by ``fn`` are recorded in ``errors``, not propagated."""
        runner = BatchRunner(items=[1, 2, 3], fn=_error_fn, max_workers=2)
        runner.run_next_batch(batch_size=10)
        assert len(runner.results) == 0
        assert len(runner.errors) == 3
        # Each recorded error should carry the original exception message.
        assert all("Failed on" in e for e in runner.errors)

    def test_tracks_progress_via_callback(self):
        """``progress_cb`` is called once per item, ending at ``(total, total)``."""
        progress_log = []

        def cb(done, total):
            progress_log.append((done, total))

        runner = BatchRunner(items=[1, 2, 3], fn=_ok_fn, max_workers=1)
        runner.run_next_batch(batch_size=10, progress_cb=cb)
        assert len(progress_log) == 3
        # Final callback should have done == total
        assert progress_log[-1] == (3, 3)

    def test_checkpointing(self):
        """Successive batches resume where the previous one stopped."""
        runner = BatchRunner(items=[1, 2, 3, 4, 5], fn=_ok_fn, max_workers=2)
        # Process first batch of 2
        runner.run_next_batch(batch_size=2)
        assert runner.has_remaining()
        assert len(runner.results) == 2
        # Process next batch of 2
        runner.run_next_batch(batch_size=2)
        assert runner.has_remaining()
        assert len(runner.results) == 4
        # Process final batch (only one item left, smaller than batch_size)
        runner.run_next_batch(batch_size=2)
        assert not runner.has_remaining()
        assert len(runner.results) == 5

    def test_total_time_seconds(self):
        """Elapsed time is 0.0 before any work and strictly positive after."""

        def slow_ok_fn(item):
            # Sleep briefly so the batch takes measurably long. In case
            # BatchRunner times itself with a low-resolution clock, a single
            # trivial item could otherwise measure as exactly 0.0 and make
            # the strict "> 0.0" assertion below flaky.
            time.sleep(0.02)
            return _ok_fn(item)

        runner = BatchRunner(items=[1], fn=slow_ok_fn, max_workers=1)
        assert runner.total_time_seconds() == 0.0
        runner.run_next_batch(batch_size=10)
        assert runner.total_time_seconds() > 0.0

    def test_mixed_success_and_failure(self):
        """Successes and failures in one batch are partitioned correctly."""

        def mixed_fn(item):
            if item % 2 == 0:
                raise ValueError(f"even: {item}")
            return {"id": item}, None, None

        runner = BatchRunner(items=[1, 2, 3, 4, 5], fn=mixed_fn, max_workers=2)
        runner.run_next_batch(batch_size=10)
        assert len(runner.results) == 3  # 1, 3, 5 succeed
        assert len(runner.errors) == 2  # 2, 4 fail
