"""Embed worker — consumes ai:embed stream, writes pgvector embeddings."""
from __future__ import annotations

import asyncio
import uuid
from typing import TYPE_CHECKING

import redis.asyncio as aioredis
import structlog
import structlog.contextvars

from app.config import get_settings
from app.db.models.job import JobStatus, JobType
from app.db.models.media import LinkStatus, MediaStatus
from app.db.repositories.job import JobRepository
from app.db.repositories.media import MediaRepository
from app.db.session import AsyncSessionLocal
from app.exceptions import AIServiceError
from app.services.ai import AIService
from app.services.queue import STREAM_NOTIF, QueueService
from app.workers.base import BaseWorker

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession

# Module-level structured logger; per-job context is bound onto it in _process.
log = structlog.get_logger(__name__)
# Resolved once at import time. NOTE(review): assumes get_settings() is cheap
# and side-effect free on import — confirm it is cached.
settings = get_settings()


class EmbedWorker(BaseWorker):
    """Consumes the ``ai:embed`` stream.

    For each message: claims the job, generates an embedding for the
    referenced media item or link, persists it via the media repository,
    and enqueues the follow-up NOTIFY job.
    """

    stream = "ai:embed"
    group = "ai:embed-workers"

    def __init__(
        self,
        redis_client: aioredis.Redis,
        queue: QueueService,
        ai: AIService,
        worker_id: int = 0,
    ) -> None:
        """Create a worker.

        Args:
            redis_client: Shared async Redis connection.
            queue: Stream queue service used to enqueue downstream jobs.
            ai: AI service providing embedding generation.
            worker_id: Index used to derive a unique consumer name
                within the Redis consumer group.
        """
        super().__init__(redis_client, queue)
        self._ai = ai
        self.consumer_name = f"embed-worker-{worker_id}"

    async def _process(self, fields: dict[str, str]) -> None:
        """Handle one stream message.

        Binds job context for structured logging and guarantees it is
        unbound afterwards — the original code skipped the unbind on
        early returns and unexpected exceptions, leaking ``job_id`` /
        ``worker`` context into the next message on this consumer.
        """
        job_id = uuid.UUID(fields["job_id"])
        structlog.contextvars.bind_contextvars(job_id=str(job_id), worker=self.consumer_name)
        try:
            await self._run_job(job_id)
        finally:
            structlog.contextvars.unbind_contextvars("job_id", "worker")

    async def _run_job(self, job_id: uuid.UUID) -> None:
        """Claim, embed, persist, and chain the notification for *job_id*."""
        async with AsyncSessionLocal() as session:
            job_repo = JobRepository(session)
            media_repo = MediaRepository(session)

            job = await job_repo.get(job_id)
            if not job:
                log.error("embed.job_not_found", job_id=str(job_id))
                return

            # Best-effort distributed lock (5 min TTL) so only one consumer
            # works a given job even if the message is delivered twice.
            lock_key = f"job:{job_id}:lock"
            acquired = await self._r.set(lock_key, self.consumer_name, nx=True, ex=300)
            if not acquired:
                log.warning("embed.lock_not_acquired", job_id=str(job_id))
                return

            try:
                claimed = await job_repo.transition(job_id, JobStatus.PENDING, JobStatus.RUNNING)
                if not claimed:
                    # Another worker already moved the job out of PENDING.
                    return

                # Read everything we need from the ORM object BEFORE commit:
                # commit may expire attributes, and refreshing them via plain
                # attribute access raises on an async session.
                entity_id = job.entity_id
                entity_type = job.entity_type or "media"
                payload = job.payload
                user_id = job.user_id
                canonical_url_hash = job.canonical_url_hash or ""

                await session.commit()

                # Build embedding input from payload (passed by analyze worker).
                title = payload.get("title", "")
                summary = payload.get("summary", "")
                keywords = payload.get("keywords", [])
                embedding_input = AIService.build_embedding_input(title, summary, keywords)

                # The OpenAI client here is synchronous; run the blocking
                # network call in a thread so it does not stall the event loop.
                embedding = await asyncio.to_thread(self._ai.generate_embedding, embedding_input)

                model_version = settings.embedding_model

                # Persist the embedding against the matching entity table.
                if entity_type == "media":
                    await media_repo.set_media_embedding(entity_id, embedding, model_version)
                else:
                    await media_repo.set_link_embedding(entity_id, embedding, model_version)
                await session.commit()

                # Chain the downstream notification job.
                notif_job = await job_repo.create(
                    entity_id=entity_id,
                    entity_type=entity_type,
                    user_id=user_id,
                    canonical_url_hash=canonical_url_hash,
                    job_type=JobType.NOTIFY,
                    payload=payload,
                )
                await self._queue.enqueue(STREAM_NOTIF, job_id=str(notif_job.id))
                await job_repo.transition(job_id, JobStatus.RUNNING, JobStatus.DONE)
                await session.commit()

                log.info("embed.complete", entity_id=str(entity_id), model=model_version)

            except AIServiceError as exc:
                # NOTE(review): only AI errors are handled; a DB error here
                # leaves the job stuck in RUNNING — confirm BaseWorker or a
                # sweeper recovers such jobs.
                await self._handle_failure(job_id, job_repo, session, exc)
            finally:
                # Release the lock even on unexpected exceptions.
                await self._r.delete(lock_key)

    async def _handle_failure(
        self,
        job_id: uuid.UUID,
        job_repo: JobRepository,
        session: AsyncSession,
        exc: Exception,
    ) -> None:
        """Record a failure for *job_id*.

        Increments the attempt counter and appends a history event, then
        either fails the job permanently (and dead-letters it) or marks it
        FAILED for a later retry. The original version annotated *session*
        as ``object`` and narrowed with ``assert isinstance`` — which is
        stripped under ``-O``; a proper annotation replaces it.
        """
        attempts = await job_repo.increment_attempts(job_id)
        await job_repo.append_history_event(job_id, "FAILED", str(exc))
        # Errors flagged permanent by the raiser skip any further retries.
        is_permanent = getattr(exc, "is_permanent", False)
        if is_permanent or attempts >= settings.job_max_attempts:
            await job_repo.transition(job_id, JobStatus.RUNNING, JobStatus.FAILED, error_message=str(exc))
            await self._queue.to_dlq(self.stream, str(job_id), str(exc), attempts)
        else:
            # NOTE(review): the retry path marks the job FAILED but does not
            # re-enqueue it — verify a sweeper/redelivery picks these up.
            await job_repo.transition(job_id, JobStatus.RUNNING, JobStatus.FAILED)
        await session.commit()