from __future__ import annotations

import uuid
from datetime import datetime, timezone

import structlog
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.media import AIAnalysis, EntityTag, Link, LinkStatus, MediaItem, MediaStatus, Tag

# Module-scoped structlog logger; the repository emits structured events with it.
log = structlog.get_logger(__name__)


class MediaRepository:
    """Async data-access layer for media items, links, AI analyses and tags.

    Every method works on the ``AsyncSession`` given to the constructor and
    calls ``flush()`` so pending SQL is sent (and generated values such as
    primary keys become available). No method commits: transaction
    boundaries are left to the caller.
    """

    # Upper bound on how many *distinct* tags ``attach_tags`` links to one entity.
    MAX_TAGS_PER_ITEM = 7
    # Tag names shorter than this (after stripping whitespace) are ignored.
    MIN_TAG_LENGTH = 2

    def __init__(self, session: AsyncSession) -> None:
        self._s = session

    # ── Shared helpers ────────────────────────────────────────────────────────

    async def _update_row(
        self,
        model: type[MediaItem] | type[Link],
        row_id: uuid.UUID,
        values: dict[str, object],
    ) -> None:
        """Apply ``values`` to the ``model`` row whose ``id`` is ``row_id``, then flush.

        ``updated_at`` is stamped with the current UTC time unless ``values``
        already provides one. A missing row simply matches nothing; no error
        is raised.
        """
        # Put the timestamp first so an explicit ``updated_at`` in ``values`` wins.
        stamped: dict[str, object] = {"updated_at": datetime.now(tz=timezone.utc), **values}
        await self._s.execute(update(model).where(model.id == row_id).values(**stamped))
        await self._s.flush()

    # ── MediaItem ─────────────────────────────────────────────────────────────

    async def create_media_item(
        self,
        user_id: uuid.UUID,
        source_url: str,
        canonical_url_hash: str,
        source_platform: str = "web",
    ) -> MediaItem:
        """Insert a new ``MediaItem`` in ``PENDING`` status and return it.

        The session is flushed so the new item's ``id`` can be logged and used
        by the caller before commit.
        """
        item = MediaItem(
            user_id=user_id,
            source_url=source_url,
            canonical_url_hash=canonical_url_hash,
            source_platform=source_platform,
            status=MediaStatus.PENDING,
        )
        self._s.add(item)
        await self._s.flush()
        log.info("media.created", entity_id=str(item.id), platform=source_platform)
        return item

    async def get_media_item(self, item_id: uuid.UUID) -> MediaItem | None:
        """Return the media item with ``item_id``, or ``None`` if there is none."""
        result = await self._s.execute(select(MediaItem).where(MediaItem.id == item_id))
        return result.scalar_one_or_none()

    async def update_media_status(
        self,
        item_id: uuid.UUID,
        status: str,
        **kwargs: object,
    ) -> None:
        """Set ``status`` (plus any extra column values in ``kwargs``) on a media item.

        ``updated_at`` is refreshed automatically; callers may override it by
        passing ``updated_at=`` explicitly.
        """
        await self._update_row(MediaItem, item_id, {"status": status, **kwargs})

    async def set_media_embedding(
        self,
        item_id: uuid.UUID,
        embedding: list[float],
        model_version: str,
    ) -> None:
        """Store an embedding on a media item and mark it ``COMPLETE``."""
        await self._update_row(
            MediaItem,
            item_id,
            {
                "embedding": embedding,
                "embedding_model_version": model_version,
                "status": MediaStatus.COMPLETE,
            },
        )

    async def list_for_user(
        self,
        user_id: uuid.UUID,
        limit: int = 20,
        offset: int = 0,
    ) -> list[MediaItem]:
        """Return one page of a user's media items, newest first.

        ``id`` is a secondary sort key. Without a unique tiebreaker, LIMIT/OFFSET
        pages over rows with equal ``created_at`` are not stable: a row could
        appear on two pages or on none.
        """
        result = await self._s.execute(
            select(MediaItem)
            .where(MediaItem.user_id == user_id)
            .order_by(MediaItem.created_at.desc(), MediaItem.id.desc())
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())

    # ── Link ──────────────────────────────────────────────────────────────────

    async def create_link(
        self,
        user_id: uuid.UUID,
        url: str,
        canonical_url_hash: str,
        page_title: str | None = None,
        og_description: str | None = None,
    ) -> Link:
        """Insert a new ``Link`` in ``PENDING`` status and return it (flushed)."""
        link = Link(
            user_id=user_id,
            url=url,
            canonical_url_hash=canonical_url_hash,
            page_title=page_title,
            og_description=og_description,
            status=LinkStatus.PENDING,
        )
        self._s.add(link)
        await self._s.flush()
        log.info("link.created", entity_id=str(link.id))
        return link

    async def get_link(self, link_id: uuid.UUID) -> Link | None:
        """Return the link with ``link_id``, or ``None`` if there is none."""
        result = await self._s.execute(select(Link).where(Link.id == link_id))
        return result.scalar_one_or_none()

    async def update_link_status(
        self,
        link_id: uuid.UUID,
        status: str,
        **kwargs: object,
    ) -> None:
        """Set ``status`` (plus any extra column values in ``kwargs``) on a link.

        ``updated_at`` is refreshed automatically; callers may override it by
        passing ``updated_at=`` explicitly.
        """
        await self._update_row(Link, link_id, {"status": status, **kwargs})

    async def set_link_embedding(
        self,
        link_id: uuid.UUID,
        embedding: list[float],
        model_version: str,
    ) -> None:
        """Store an embedding on a link and mark it ``COMPLETE``."""
        await self._update_row(
            Link,
            link_id,
            {
                "embedding": embedding,
                "embedding_model_version": model_version,
                "status": LinkStatus.COMPLETE,
            },
        )

    # ── AI Analysis ───────────────────────────────────────────────────────────

    async def save_analysis(
        self,
        entity_id: uuid.UUID,
        entity_type: str,
        summary: str,
        keywords: list[str],
        entities: dict,
        llm_model: str,
        pipeline_version: str,
        prompt_tokens: int,
        completion_tokens: int,
    ) -> AIAnalysis:
        """Insert a new ``AIAnalysis`` row for ``(entity_id, entity_type)`` and return it.

        Existing analyses are not replaced; each call adds a row (the newest is
        what ``get_analysis`` returns). ``entities`` is stored in ``entities_json``.
        """
        analysis = AIAnalysis(
            entity_id=entity_id,
            entity_type=entity_type,
            summary=summary,
            keywords=keywords,
            entities_json=entities,
            llm_model=llm_model,
            pipeline_version=pipeline_version,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        self._s.add(analysis)
        await self._s.flush()
        return analysis

    async def get_analysis(self, entity_id: uuid.UUID, entity_type: str) -> AIAnalysis | None:
        """Return the most recently created analysis for an entity, or ``None``."""
        result = await self._s.execute(
            select(AIAnalysis)
            .where(AIAnalysis.entity_id == entity_id, AIAnalysis.entity_type == entity_type)
            .order_by(AIAnalysis.created_at.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()

    # ── Tags ──────────────────────────────────────────────────────────────────

    @staticmethod
    def _slugify(name: str) -> str:
        """Lower-case ``name`` and turn spaces and hyphens into underscores.

        Intentionally minimal: altering this mapping would stop existing slugs
        in the database from matching new lookups.
        """
        return name.lower().replace(" ", "_").replace("-", "_")

    @classmethod
    def _normalize_tag_names(
        cls,
        tag_names: list[str],
        max_tags: int = MAX_TAGS_PER_ITEM,
    ) -> list[str]:
        """Pick the tag names ``attach_tags`` should link, in input order.

        Each name is stripped; names shorter than ``MIN_TAG_LENGTH`` are
        skipped, as are names whose slug was already selected. At most
        ``max_tags`` names are returned. Filtering happens *before* the cap, so
        invalid or duplicate entries do not use up the budget.
        """
        selected: list[str] = []
        seen_slugs: set[str] = set()
        for raw in tag_names:
            if len(selected) >= max_tags:
                break
            name = raw.strip()
            if len(name) < cls.MIN_TAG_LENGTH:
                continue
            slug = cls._slugify(name)
            if slug in seen_slugs:
                continue
            seen_slugs.add(slug)
            selected.append(name)
        return selected

    async def upsert_tag(self, user_id: uuid.UUID, name: str) -> Tag:
        """Return the user's tag for ``name``, creating it if needed.

        Uses ``INSERT ... ON CONFLICT DO NOTHING`` on ``uq_tags_user_slug``
        (presumably unique on ``(user_id, slug)``, matching the fallback lookup
        below). When the tag already exists, the stored row is returned as-is
        and its ``name`` is not overwritten.
        """
        from sqlalchemy.dialects.postgresql import insert as pg_insert

        slug = self._slugify(name)
        stmt = (
            pg_insert(Tag)
            .values(user_id=user_id, name=name, slug=slug)
            .on_conflict_do_nothing(constraint="uq_tags_user_slug")
            .returning(Tag)
        )
        result = await self._s.execute(stmt)
        await self._s.flush()
        row = result.scalar_one_or_none()
        if row is None:
            # RETURNING yields nothing when the conflict path was taken — fetch it.
            existing = await self._s.execute(
                select(Tag).where(Tag.user_id == user_id, Tag.slug == slug)
            )
            row = existing.scalar_one()
        return row

    async def attach_tags(
        self,
        entity_id: uuid.UUID,
        entity_type: str,
        user_id: uuid.UUID,
        tag_names: list[str],
        source: str = "ai_generated",
        max_tags: int = MAX_TAGS_PER_ITEM,
    ) -> None:
        """Link up to ``max_tags`` distinct, valid tags to an entity.

        Names are cleaned and deduplicated by ``_normalize_tag_names``; each
        surviving name is upserted as a ``Tag`` and linked via ``EntityTag``.
        Junction inserts use ``ON CONFLICT DO NOTHING`` so re-attaching an
        existing link is a no-op (assumes a unique constraint on the junction
        table — not visible here, verify against the schema).
        """
        from sqlalchemy.dialects.postgresql import insert as pg_insert

        for name in self._normalize_tag_names(tag_names, max_tags):
            tag = await self.upsert_tag(user_id, name)
            stmt = (
                pg_insert(EntityTag)
                .values(entity_id=entity_id, entity_type=entity_type, tag_id=tag.id, source=source)
                .on_conflict_do_nothing()
            )
            await self._s.execute(stmt)
        await self._s.flush()
