from __future__ import annotations

import hashlib
import ipaddress
import re
import socket
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from app.exceptions import SSRFError, ValidationError

# ── Domain allowlist ──────────────────────────────────────────────────────────

# Hosts accepted by validate_url(): a URL passes when its hostname — after a
# leading "www." is stripped — equals one of these entries exactly or is a
# subdomain of one (e.g. "clips.twitch.tv" matches via "twitch.tv").
ALLOWED_DOMAINS: frozenset[str] = frozenset(
    {
        # Short-link hosts (t.co, youtu.be, redd.it, vm.tiktok.com, fb.watch)
        # are listed alongside their parent platforms so shared links resolve.
        "twitter.com", "x.com", "t.co",
        "tiktok.com", "vm.tiktok.com",
        "youtube.com", "youtu.be",
        "instagram.com",
        "reddit.com", "redd.it",
        "facebook.com", "fb.com", "fb.watch",
        "linkedin.com",
        "twitch.tv", "clips.twitch.tv",
        "vimeo.com",
        "dailymotion.com",
        "soundcloud.com",
        "open.spotify.com",
        "medium.com",
        "substack.com",
        "github.com",
        "news.ycombinator.com",
    }
)

# Query parameters stripped from all URLs during normalization: analytics /
# share-tracking noise that does not change which resource is addressed.
# Stripping them keeps canonical forms (and their hashes) stable across shares.
_TRACKING_PARAMS: frozenset[str] = frozenset(
    {
        "utm_source", "utm_medium", "utm_campaign", "utm_term", "utm_content",
        "fbclid", "gclid", "msclkid", "twclid",
        "ref", "ref_src", "ref_url",
        "si",       # Spotify share ID
        "feature",  # YouTube feature param
        "s",        # Twitter share
        "t",        # Twitter timestamp / TikTok
        "_branch_match_id",
        "igshid",   # Instagram share
    }
)

# Private / reserved IP ranges that should never be contacted
_PRIVATE_NETWORKS = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
    ipaddress.ip_network("127.0.0.0/8"),
    ipaddress.ip_network("169.254.0.0/16"),   # link-local / AWS metadata
    ipaddress.ip_network("::1/128"),
    ipaddress.ip_network("fc00::/7"),
    ipaddress.ip_network("fe80::/10"),
]

URL_MAX_LENGTH = 2048
_SAFE_CHARS_RE = re.compile(r"^[\x20-\x7E\u00A0-\uFFFF]+$")


def _is_private_ip(addr: str) -> bool:
    try:
        ip = ipaddress.ip_address(addr)
    except ValueError:
        return True  # unparseable — treat as unsafe
    return any(ip in net for net in _PRIVATE_NETWORKS)


def validate_url(raw_url: str) -> str:
    """
    Validate and normalize a user-supplied URL.

    Checks, in order: non-empty, length, character safety, scheme, host
    presence, domain allowlist, then (via DNS) that the host does not
    resolve to a private/reserved IP.

    Raises ValidationError for invalid/disallowed URLs.
    Raises SSRFError if the URL resolves to a private IP.
    Returns the normalized URL string.
    """
    raw_url = raw_url.strip()

    if not raw_url:
        raise ValidationError("URL is empty")

    if len(raw_url) > URL_MAX_LENGTH:
        raise ValidationError(f"URL too long (max {URL_MAX_LENGTH} chars)")

    # Reject control characters (e.g. CR/LF) that could enable header
    # injection in downstream HTTP clients.
    if not _SAFE_CHARS_RE.match(raw_url):
        raise ValidationError("URL contains invalid characters")

    parsed = urlparse(raw_url)

    if parsed.scheme not in ("http", "https"):
        raise ValidationError("Only http:// and https:// URLs are accepted")

    host = parsed.hostname or ""
    if not host:
        raise ValidationError("URL has no host")

    # Allowlist: exact match or any subdomain of an allowed domain.
    normalized_host = host.lower().removeprefix("www.")
    if not any(
        normalized_host == d or normalized_host.endswith("." + d)
        for d in ALLOWED_DOMAINS
    ):
        raise ValidationError(
            f"Domain '{host}' is not in the supported list. "
            "Supported: Twitter/X, TikTok, YouTube, Instagram, Reddit, and others."
        )

    # SSRF guard: resolve DNS and check all returned IPs.
    # NOTE(review): the eventual fetch re-resolves the hostname, so this
    # check alone is still subject to DNS rebinding — pinning the resolved
    # IP at fetch time would close that gap.
    try:
        addr_infos = socket.getaddrinfo(host, None)
    except socket.gaierror as exc:
        raise ValidationError(f"Could not resolve hostname '{host}': {exc}") from exc

    for addr_info in addr_infos:
        ip_str = addr_info[4][0]
        if _is_private_ip(ip_str):
            raise SSRFError(
                f"URL resolves to a private/reserved IP address: {ip_str}"
            )

    return normalize_url(raw_url)


def normalize_url(raw_url: str) -> str:
    """
    Produce a deterministic canonical form for deduplication.
    Does NOT validate — call validate_url() first.
    """
    parts = urlparse(raw_url)
    scheme = parts.scheme.lower()
    host = (parts.hostname or "").lower().removeprefix("www.")

    # Drop the scheme's default port; keep any explicit non-default port.
    port = parts.port
    default_port = {"http": 80, "https": 443}.get(scheme)
    netloc = host if port in (None, default_port) else f"{host}:{port}"

    # Trailing-slash-insensitive path ("/" for the bare root).
    path = parts.path.rstrip("/") or "/"
    query_map = parse_qs(parts.query, keep_blank_values=False)

    # Platform-specific normalization: collapse every YouTube URL shape into
    # the canonical watch URL.
    if host in ("youtube.com", "youtu.be"):
        if host == "youtu.be":
            video_id = path.lstrip("/")
        else:
            video_id = (query_map.get("v") or [None])[0]
        if video_id:
            return f"https://www.youtube.com/watch?v={video_id}"

    # Drop known tracking parameters; sort keys for a deterministic query.
    kept = sorted(
        (name, values)
        for name, values in query_map.items()
        if name not in _TRACKING_PARAMS
    )
    return urlunparse((scheme, netloc, path, "", urlencode(kept, doseq=True), ""))


def url_to_hash(normalized_url: str) -> str:
    """SHA-256 hex digest of the normalized URL — used as canonical_url_hash."""
    digest = hashlib.sha256()
    digest.update(normalized_url.encode())
    return digest.hexdigest()


def detect_platform(url: str) -> str:
    """Return a short platform slug from a normalized URL."""
    host = (urlparse(url).hostname or "").removeprefix("www.")
    # Ordered (domain, slug) pairs; a host matches on equality or as a
    # subdomain. Unrecognized hosts fall through to the generic "web" slug.
    domain_slugs = (
        ("twitter.com", "twitter"),
        ("x.com", "twitter"),
        ("t.co", "twitter"),
        ("tiktok.com", "tiktok"),
        ("vm.tiktok.com", "tiktok"),
        ("youtube.com", "youtube"),
        ("youtu.be", "youtube"),
        ("instagram.com", "instagram"),
        ("reddit.com", "reddit"),
        ("redd.it", "reddit"),
        ("facebook.com", "facebook"),
        ("fb.com", "facebook"),
        ("fb.watch", "facebook"),
        ("vimeo.com", "vimeo"),
        ("soundcloud.com", "soundcloud"),
        ("open.spotify.com", "spotify"),
    )
    return next(
        (
            slug
            for domain, slug in domain_slugs
            if host == domain or host.endswith("." + domain)
        ),
        "web",
    )
