"""
Base worker class for Redis Streams consumers.

Each worker:
  1. Claims messages from its stream via XREADGROUP
  2. Processes the message (subclass implements _process)
  3. ACKs the message on success
  4. On failure: leaves the message un-ACKed so it will be reclaimed and
     redelivered; the job repository tracks retry counts and DLQ routing

At startup (before entering the claim loop) the worker runs XAUTOCLAIM
to reclaim messages abandoned by crashed workers.
"""
from __future__ import annotations

import asyncio
import random
import uuid
from abc import ABC, abstractmethod
from typing import Any

import redis.asyncio as aioredis
import structlog

from app.db.repositories.job import JobRepository
from app.db.models.job import JobStatus
from app.services.queue import QueueService

log = structlog.get_logger(__name__)

# Time to wait when the stream has no new messages
_BLOCK_MS = 5_000  # 5 seconds
# Idle threshold for XAUTOCLAIM (reclaim messages idle longer than this)
_AUTOCLAIM_MIN_IDLE_MS = 300_000  # 5 minutes


class BaseWorker(ABC):
    """Abstract base for Redis Streams workers.

    Subclasses set ``stream``, ``group`` and ``consumer_name`` and implement
    :meth:`_process`. Delivery is at-least-once: a message is ACKed only
    after ``_process`` returns, so a crash between processing and ACK leads
    to redelivery via XAUTOCLAIM.
    """

    stream: str          # subclass sets this
    group: str           # subclass sets this
    consumer_name: str   # subclass sets this (unique per process/thread)

    def __init__(
        self,
        redis_client: aioredis.Redis,
        queue: QueueService,
    ) -> None:
        self._r = redis_client
        self._queue = queue
        self._running = False

    async def start(self) -> None:
        """Run the consume loop until :meth:`stop` is called.

        First reclaims messages abandoned by previously crashed consumers,
        then repeatedly blocks on XREADGROUP for up to ``_BLOCK_MS`` ms.
        Unexpected loop errors are logged and followed by a short sleep so
        a persistent Redis failure cannot hot-spin the loop.
        """
        self._running = True
        log.info("worker.start", stream=self.stream, consumer=self.consumer_name)

        # Reclaim any messages abandoned by previously crashed workers
        await self._autoclaim_pending()

        while self._running:
            try:
                messages = await self._r.xreadgroup(
                    groupname=self.group,
                    consumername=self.consumer_name,
                    streams={self.stream: ">"},  # ">" = only never-delivered entries
                    count=1,
                    block=_BLOCK_MS,
                )
                if not messages:
                    # Block timed out with nothing to read; re-check _running.
                    continue

                for _stream, entries in messages:
                    for msg_id, fields in entries:
                        await self._handle_message(msg_id, fields)

            except asyncio.CancelledError:
                # Task cancellation is a shutdown signal, not an error.
                break
            except Exception as exc:
                log.error("worker.loop_error", stream=self.stream, error=str(exc))
                await asyncio.sleep(1)

        log.info("worker.stopped", stream=self.stream)

    def stop(self) -> None:
        """Request a graceful stop.

        The loop may take up to ``_BLOCK_MS`` ms to notice, since the flag
        is only re-checked after the blocking XREADGROUP returns.
        """
        self._running = False

    async def _handle_message(self, msg_id: bytes | str, fields: dict[bytes, bytes]) -> None:
        """Process one stream entry; ACK on success, leave pending on failure.

        Failures are deliberately NOT acked: the entry stays in the PEL so
        XAUTOCLAIM can redeliver it. Retry counting and DLQ routing are the
        job repository's responsibility, not this loop's.
        """
        # Redis returns keys/values as bytes; normalize to str for _process.
        decoded = {
            (k.decode() if isinstance(k, bytes) else k): (
                v.decode() if isinstance(v, bytes) else v
            )
            for k, v in fields.items()
        }
        job_id = decoded.get("job_id", "unknown")
        # Decode the entry id too, so logs show "1-0" rather than b"1-0".
        log_msg_id = msg_id.decode() if isinstance(msg_id, bytes) else msg_id

        log.info("worker.message_received", stream=self.stream, msg_id=log_msg_id, job_id=job_id)

        try:
            await self._process(decoded)
            await self._r.xack(self.stream, self.group, msg_id)
            log.info("worker.message_done", stream=self.stream, job_id=job_id)

        except Exception as exc:
            log.error(
                "worker.message_failed",
                stream=self.stream,
                job_id=job_id,
                error=str(exc),
                exc_info=True,
            )
            # Do NOT ack — message stays in PEL; XAUTOCLAIM will reclaim
            # The job repository handles retry counting and DLQ routing

    async def _autoclaim_pending(self) -> None:
        """Reclaim and process messages idle beyond ``_AUTOCLAIM_MIN_IDLE_MS``.

        XAUTOCLAIM scans the PEL in batches and returns a cursor for the next
        call; iterating until the cursor wraps back to "0-0" ensures ALL
        abandoned messages are reclaimed, not just the first batch of 10.
        Errors are logged and swallowed: autoclaim is best-effort and must
        not prevent the worker from starting.
        """
        cursor: bytes | str = "0-0"
        try:
            while True:
                result = await self._r.xautoclaim(
                    name=self.stream,
                    groupname=self.group,
                    consumername=self.consumer_name,
                    min_idle_time=_AUTOCLAIM_MIN_IDLE_MS,
                    start_id=cursor,
                    count=10,
                )
                if not result:
                    break
                # result is (next_cursor, messages, deleted_ids)
                cursor, messages = result[0], result[1]
                if messages:
                    log.info("worker.autoclaim", stream=self.stream, count=len(messages))
                    for msg_id, fields in messages:
                        await self._handle_message(msg_id, fields)
                # A "0-0" cursor means the PEL scan wrapped around — done.
                if not messages or cursor in ("0-0", b"0-0"):
                    break
        except Exception as exc:
            log.warning("worker.autoclaim_error", stream=self.stream, error=str(exc))

    @staticmethod
    def _backoff(attempt: int) -> float:
        """Return a retry delay: 2**attempt * 10s with symmetric ±10% jitter.

        The original implementation only ever ADDED 0–10% jitter
        (``random.random()`` is non-negative), contradicting the documented
        "± 10%"; ``random.uniform(-1, 1)`` makes the jitter symmetric so
        retrying consumers spread both before and after the nominal delay.
        """
        base = (2 ** attempt) * 10
        return base + base * 0.1 * random.uniform(-1.0, 1.0)

    @abstractmethod
    async def _process(self, fields: dict[str, str]) -> None:
        """Process a single message. Raise on failure."""
        ...
