from __future__ import annotations

import uuid
from datetime import datetime, timezone

import structlog
from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.job import ALLOWED_TRANSITIONS, Job, JobStatus
from app.exceptions import InvalidJobTransitionError

log = structlog.get_logger(__name__)


class JobRepository:
    """Async data-access helpers for ``Job`` rows, bound to one ``AsyncSession``.

    Methods flush but never commit; the caller owns the transaction boundary.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._s = session

    async def create(
        self,
        entity_id: uuid.UUID,
        entity_type: str,
        user_id: uuid.UUID,
        canonical_url_hash: str,
        job_type: str,
        payload: dict | None = None,
        max_attempts: int = 3,
    ) -> Job:
        """
        Insert a new PENDING job and flush it so its id is populated.

        Deduplication is not done here. It relies on the database's partial
        unique index on ``canonical_url_hash``. If a duplicate gets past the
        caller's ``find_active_by_hash`` check, the flush raises (presumably
        an ``IntegrityError``; confirm against the index definition).

        Args:
            entity_id: Id of the entity the job operates on.
            entity_type: Kind of that entity.
            user_id: Id of the requesting user.
            canonical_url_hash: Hash used for deduplication.
            job_type: Job kind identifier.
            payload: Arbitrary job data; ``None`` is stored as ``{}``.
            max_attempts: Retry budget stored on the job.

        Returns:
            The flushed ``Job`` instance.
        """
        job = Job(
            entity_id=entity_id,
            entity_type=entity_type,
            user_id=user_id,
            canonical_url_hash=canonical_url_hash,
            job_type=job_type,
            status=JobStatus.PENDING,
            payload=payload or {},
            max_attempts=max_attempts,
        )
        self._s.add(job)
        await self._s.flush()
        log.info("job.created", job_id=str(job.id), job_type=job_type, entity_id=str(entity_id))
        return job

    async def get(self, job_id: uuid.UUID) -> Job | None:
        """Return the job with ``job_id``, or ``None`` if it does not exist."""
        result = await self._s.execute(select(Job).where(Job.id == job_id))
        return result.scalar_one_or_none()

    async def find_active_by_hash(self, canonical_url_hash: str) -> Job | None:
        """
        Return a job for this URL hash that is still in flight, or ``None``.

        "In flight" means the status is not FAILED, EXPIRED or DONE. No
        ORDER BY is applied. If several such jobs could exist, the one
        returned is arbitrary. Presumably the partial unique index prevents
        that; verify.
        """
        result = await self._s.execute(
            select(Job)
            .where(
                Job.canonical_url_hash == canonical_url_hash,
                Job.status.not_in([JobStatus.FAILED, JobStatus.EXPIRED, JobStatus.DONE]),
            )
            .limit(1)
        )
        return result.scalar_one_or_none()

    async def transition(
        self,
        job_id: uuid.UUID,
        from_status: str,
        to_status: str,
        error_message: str | None = None,
    ) -> bool:
        """
        Atomically move a job from ``from_status`` to ``to_status``.

        This is optimistic locking. The UPDATE matches only if the row is
        still in ``from_status``, so at most one concurrent caller wins.
        Moving to RUNNING stamps ``started_at``. Moving to DONE, FAILED or
        EXPIRED stamps ``completed_at``. ``error_message`` is written only
        when it is truthy; an existing message is never cleared.

        Returns:
            True if this call performed the transition. False if the job is
            missing or was no longer in ``from_status``.

        Raises:
            InvalidJobTransitionError: ``to_status`` is not an allowed
                successor of ``from_status`` in ``ALLOWED_TRANSITIONS``.
        """
        if to_status not in ALLOWED_TRANSITIONS.get(from_status, set()):
            raise InvalidJobTransitionError(str(job_id), from_status, to_status)

        now = datetime.now(tz=timezone.utc)
        values: dict[str, object] = {"status": to_status}

        if to_status == JobStatus.RUNNING:
            values["started_at"] = now
        elif to_status in (JobStatus.DONE, JobStatus.FAILED, JobStatus.EXPIRED):
            values["completed_at"] = now

        if error_message:
            values["error_message"] = error_message

        # The status guard in WHERE is the lock. RETURNING tells us whether
        # our UPDATE matched a row.
        result = await self._s.execute(
            update(Job)
            .where(Job.id == job_id, Job.status == from_status)
            .values(**values)
            .returning(Job.id)
        )
        await self._s.flush()
        success = result.scalar_one_or_none() is not None

        if success:
            log.info(
                "job.transition",
                job_id=str(job_id),
                from_status=from_status,
                to_status=to_status,
            )
        else:
            log.warning(
                "job.transition.conflict",
                job_id=str(job_id),
                from_status=from_status,
                to_status=to_status,
            )
        return success

    async def increment_attempts(self, job_id: uuid.UUID) -> int:
        """
        Atomically increment the job's attempt counter and return the new value.

        The increment runs as a single ``UPDATE ... RETURNING`` statement.
        Concurrent workers therefore cannot lose each other's increments,
        which a SELECT followed by a separate write could. A NULL counter is
        treated as 0.

        Returns:
            The updated attempt count, or 0 if no job with ``job_id`` exists.
        """
        result = await self._s.execute(
            update(Job)
            .where(Job.id == job_id)
            # coalesce keeps the NULL-as-0 behaviour; NULL + 1 would stay NULL.
            .values(attempts=func.coalesce(Job.attempts, 0) + 1)
            .returning(Job.attempts)
        )
        new_attempts = result.scalar_one_or_none()
        return 0 if new_attempts is None else new_attempts

    async def append_history_event(
        self, job_id: uuid.UUID, status: str, detail: str | None = None
    ) -> None:
        """
        Append a ``{"status", "at", "detail"}`` entry to the job's status history.

        ``at`` is the current UTC time in ISO-8601 form. The call does
        nothing if the job does not exist.

        NOTE(review): this is a read-modify-write of the whole history list.
        Concurrent appends to the same job can drop events. Consider a
        row lock (``with_for_update``) or an in-database JSON append if
        that matters.
        """
        job = await self.get(job_id)
        if job is None:
            return
        history = list(job.status_history or [])
        history.append({
            "status": status,
            "at": datetime.now(tz=timezone.utc).isoformat(),
            "detail": detail,
        })
        # Assign a new list rather than mutating in place. This presumably
        # lets the ORM detect the change on a plain (non-Mutable) JSON column.
        job.status_history = history
        await self._s.flush()
