"""Hybrid search service — combines pgvector ANN + PostgreSQL full-text search."""
from __future__ import annotations

import uuid
from dataclasses import dataclass
from datetime import datetime
from urllib.parse import urlparse

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.ai import AIService


@dataclass
class SearchResult:
    entity_id: uuid.UUID
    entity_type: str   # "media" | "link"
    title: str
    url: str
    domain: str
    summary: str
    tags: list[str]
    similarity: float
    created_at: datetime | None


class SearchService:
    """Dispatches semantic (pgvector ANN), keyword (Postgres FTS), or hybrid (RRF-merged) search."""

    def __init__(self, session: AsyncSession, ai: AIService) -> None:
        self._s = session
        self._ai = ai

    async def search(
        self,
        query: str,
        user_id: uuid.UUID,
        mode: str = "hybrid",     # "semantic" | "keyword" | "hybrid"
        entity_type: str = "all", # "all" | "media" | "link"
        limit: int = 20,
        offset: int = 0,
        tag: str | None = None,
        platform: str | None = None,
        from_date: datetime | None = None,
        to_date: datetime | None = None,
    ) -> list[SearchResult]:
        """Route to the requested mode; any unrecognized mode falls back to hybrid."""
        if mode == "keyword":
            return await self._keyword_search(
                query, user_id, entity_type, limit, offset, tag, platform, from_date, to_date
            )
        elif mode == "semantic":
            return await self._semantic_search(
                query, user_id, entity_type, limit, offset, tag, from_date, to_date
            )
        else:
            # Hybrid: over-fetch both lists (2x limit, offset 0), then fuse with
            # Reciprocal Rank Fusion; paging is applied to the fused ranking.
            sem = await self._semantic_search(
                query, user_id, entity_type, limit * 2, 0, tag, from_date, to_date
            )
            kw = await self._keyword_search(
                query, user_id, entity_type, limit * 2, 0, tag, platform, from_date, to_date
            )
            return _merge_results(sem, kw, limit, offset)
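
    # Usage sketch (illustrative — `async_session`, `uid`, and the AIService
    # constructor call are assumptions about the surrounding app, not defined
    # in this module):
    #
    #   async with async_session() as session:
    #       svc = SearchService(session, AIService())
    #       hits = await svc.search("postgres hnsw tuning", user_id=uid)
    #       for h in hits:
    #           print(f"{h.similarity:.2f}  {h.title}  ({h.domain})")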

    async def _semantic_search(
        self,
        query: str,
        user_id: uuid.UUID,
        entity_type: str,
        limit: int,
        offset: int,
        tag: str | None,
        from_date: datetime | None,
        to_date: datetime | None,
    ) -> list[SearchResult]:
        query_embedding = self._ai.generate_embedding(query)
        # pgvector accepts the "[v1,v2,...]" text literal; the SQL casts it via ::vector.
        embedding_str = "[" + ",".join(str(v) for v in query_embedding) + "]"

        results: list[SearchResult] = []

        for etype in _entity_types(entity_type):
            table = "media_items" if etype == "media" else "links"
            url_col = "source_url" if etype == "media" else "url"
            title_col = "page_title"

            # Stage 1: ANN top-200 among the user's rows (the user filter sits
            # inside the CTE); stage 2: similarity threshold, tag, and date
            # filters, then paging.
            # Tag filter via EXISTS so the LEFT JOIN below can still aggregate
            # every tag on a matching item.
            tag_filter = (
                """AND EXISTS (
                        SELECT 1 FROM entity_tags et2
                        JOIN tags t2 ON t2.id = et2.tag_id
                        WHERE et2.entity_id = c.id
                          AND et2.entity_type = :entity_type
                          AND t2.name = :tag
                    )"""
                if tag
                else ""
            )
            sql = text(f"""
                WITH candidates AS (
                    SELECT
                        id,
                        {url_col}       AS url,
                        {title_col}     AS title,
                        created_at,
                        1 - (embedding <=> :embedding::vector) AS similarity
                    FROM {table}
                    WHERE user_id = :user_id
                      AND embedding IS NOT NULL
                    ORDER BY embedding <=> :embedding::vector
                    LIMIT 200
                )
                SELECT c.id, c.url, c.title, c.created_at, c.similarity,
                       a.summary, array_agg(DISTINCT t.name) AS tags
                FROM candidates c
                LEFT JOIN ai_analyses a ON a.entity_id = c.id AND a.entity_type = :entity_type
                LEFT JOIN entity_tags et ON et.entity_id = c.id AND et.entity_type = :entity_type
                LEFT JOIN tags t ON t.id = et.tag_id
                WHERE c.similarity > 0.3
                  {tag_filter}
                  {_date_filter("c", from_date, to_date)}
                GROUP BY c.id, c.url, c.title, c.created_at, c.similarity, a.summary
                ORDER BY c.similarity DESC
                LIMIT :limit OFFSET :offset
            """)

            params: dict = {
                "embedding": embedding_str,
                "user_id": user_id,
                "entity_type": etype,
                "limit": limit,
                "offset": offset,
            }
            if tag:
                params["tag"] = tag
            if from_date:
                params["from_date"] = from_date
            if to_date:
                params["to_date"] = to_date

            rows = await self._s.execute(sql, params)
            for row in rows:
                results.append(_row_to_result(row, etype))

        return results
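
    # The candidates CTE leans on an ANN index over `embedding`; without one,
    # `ORDER BY embedding <=> ...` degrades to a sequential scan. A minimal
    # sketch, assuming pgvector >= 0.5 and cosine distance (both assumptions,
    # not confirmed by this module):
    #
    #   CREATE INDEX IF NOT EXISTS media_items_embedding_idx
    #       ON media_items USING hnsw (embedding vector_cosine_ops);
    #   CREATE INDEX IF NOT EXISTS links_embedding_idx
    #       ON links USING hnsw (embedding vector_cosine_ops);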

    async def _keyword_search(
        self,
        query: str,
        user_id: uuid.UUID,
        entity_type: str,
        limit: int,
        offset: int,
        tag: str | None,
        platform: str | None,
        from_date: datetime | None,
        to_date: datetime | None,
    ) -> list[SearchResult]:
        results: list[SearchResult] = []

        for etype in _entity_types(entity_type):
            table = "media_items" if etype == "media" else "links"
            url_col = "source_url" if etype == "media" else "url"
            title_col = "page_title"
            platform_filter = (
                "AND m.source_platform = :platform"
                if etype == "media" and platform
                else ""
            )
            # Tag filter via EXISTS so the LEFT JOIN below can still aggregate
            # every tag on a matching item.
            tag_filter = (
                """AND EXISTS (
                        SELECT 1 FROM entity_tags et2
                        JOIN tags t2 ON t2.id = et2.tag_id
                        WHERE et2.entity_id = m.id
                          AND et2.entity_type = :entity_type
                          AND t2.name = :tag
                    )"""
                if tag
                else ""
            )

            sql = text(f"""
                SELECT m.id, m.{url_col} AS url, m.{title_col} AS title,
                       m.created_at,
                       ts_rank(to_tsvector('simple', coalesce(m.{title_col},'')),
                               plainto_tsquery('simple', :query)) AS similarity,
                       a.summary,
                       array_agg(DISTINCT t.name) AS tags
                FROM {table} m
                LEFT JOIN ai_analyses a ON a.entity_id = m.id AND a.entity_type = :entity_type
                LEFT JOIN entity_tags et ON et.entity_id = m.id AND et.entity_type = :entity_type
                LEFT JOIN tags t ON t.id = et.tag_id
                WHERE m.user_id = :user_id
                  AND to_tsvector('simple', coalesce(m.{title_col},''))
                      @@ plainto_tsquery('simple', :query)
                  {platform_filter}
                  {tag_filter}
                  {_date_filter("m", from_date, to_date)}
                GROUP BY m.id, m.{url_col}, m.{title_col}, m.created_at, a.summary
                ORDER BY similarity DESC
                LIMIT :limit OFFSET :offset
            """)

            params: dict = {
                "query": query,
                "user_id": user_id,
                "entity_type": etype,
                "limit": limit,
                "offset": offset,
            }
            if etype == "media" and platform:
                params["platform"] = platform
            if tag:
                params["tag"] = tag
            if from_date:
                params["from_date"] = from_date
            if to_date:
                params["to_date"] = to_date

            rows = await self._s.execute(sql, params)
            for row in rows:
                results.append(_row_to_result(row, etype))

        return results
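
    # The @@ filter recomputes to_tsvector() per row unless an expression index
    # matches it exactly. A sketch, assuming the 'simple' config stays fixed
    # (the index expression must mirror the query's expression):
    #
    #   CREATE INDEX IF NOT EXISTS links_page_title_fts_idx
    #       ON links USING gin (to_tsvector('simple', coalesce(page_title, '')));
    #   -- plus the equivalent index on media_items(page_title)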


# ── Helpers ───────────────────────────────────────────────────────────────────

def _entity_types(entity_type: str) -> list[str]:
    if entity_type == "media":
        return ["media"]
    if entity_type == "link":
        return ["link"]
    return ["media", "link"]


def _date_filter(alias: str, from_date: datetime | None, to_date: datetime | None) -> str:
    """Render optional created_at bounds as SQL clauses.

    A clause is emitted only for a date that is actually set, because the
    callers bind :from_date / :to_date conditionally and SQLAlchemy raises
    on a referenced-but-missing bind parameter.
    """
    clauses: list[str] = []
    if from_date is not None:
        clauses.append(f"AND {alias}.created_at >= :from_date")
    if to_date is not None:
        clauses.append(f"AND {alias}.created_at <= :to_date")
    return " ".join(clauses)

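# Fragments rendered by _date_filter (illustrative):
#   _date_filter("c", from_dt, None)  -> "AND c.created_at >= :from_date"
#   _date_filter("c", from_dt, to_dt) -> "AND c.created_at >= :from_date AND c.created_at <= :to_date"
#   _date_filter("c", None, None)     -> ""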

def _row_to_result(row: object, entity_type: str) -> SearchResult:
    url = row.url or ""  # type: ignore[attr-defined]
    host = urlparse(url).hostname or url
    domain = host.removeprefix("www.")
    # array_agg over a LEFT JOIN yields [None] for items with no tags; drop falsy entries.
    raw_tags = row.tags or []  # type: ignore[attr-defined]
    tags = [t for t in raw_tags if t]
    return SearchResult(
        entity_id=row.id,  # type: ignore[attr-defined]
        entity_type=entity_type,
        title=row.title or "",  # type: ignore[attr-defined]
        url=url,
        domain=domain,
        summary=row.summary or "",  # type: ignore[attr-defined]
        tags=tags,
        similarity=float(row.similarity or 0),  # type: ignore[attr-defined]
        created_at=row.created_at,  # type: ignore[attr-defined]
    )


def _merge_results(
    semantic: list[SearchResult],
    keyword: list[SearchResult],
    limit: int,
    offset: int,
) -> list[SearchResult]:
    """RRF (Reciprocal Rank Fusion) merge of semantic + keyword results."""
    K = 60  # conventional RRF damping constant; larger K flattens rank differences
    scores: dict[uuid.UUID, float] = {}
    combined: dict[uuid.UUID, SearchResult] = {}

    for rank, r in enumerate(semantic):
        scores[r.entity_id] = scores.get(r.entity_id, 0) + 1 / (K + rank + 1)
        combined[r.entity_id] = r

    for rank, r in enumerate(keyword):
        scores[r.entity_id] = scores.get(r.entity_id, 0) + 1 / (K + rank + 1)
        if r.entity_id not in combined:
            combined[r.entity_id] = r  # keep the semantic copy when both lists match

    ranked = sorted(combined.values(), key=lambda r: scores[r.entity_id], reverse=True)
    return ranked[offset : offset + limit]
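
# Worked example (illustrative): an item ranked 1st semantically (rank 0) and
# 3rd by keyword (rank 2) scores 1/61 + 1/63 ≈ 0.032, beating an item that
# tops only one list (1/61 ≈ 0.016). RRF fuses on rank alone, so the
# incomparable cosine-similarity and ts_rank scales never need normalizing.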
