from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Awaitable

from config.settings import settings

logger = logging.getLogger(__name__)


@dataclass
class SignalEntry:
    symbol: str
    action: str        # "BUY" | "SELL"
    confidence: float  # 0.0 إلى 1.0
    source: str
    received_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))


class SignalAggregator:
    """
    يجمع الإشارات الواردة في نافذة زمنية ثابتة.
    عند انتهاء النافذة يحسب الـ score ويستدعي callback القرار.
    """

    def __init__(self, on_flush: Callable[[str, float, int], Awaitable[None]]) -> None:
        # on_flush(symbol, score, signal_count) يُستدعى عند انتهاء كل نافذة
        self._on_flush = on_flush
        self._buffer: dict[str, list[SignalEntry]] = {}
        self._timers: dict[str, asyncio.TimerHandle] = {}

    async def receive(
        self, symbol: str, action: str, confidence: float, source: str
    ) -> None:
        symbol = symbol.upper()
        entry = SignalEntry(
            symbol=symbol, action=action, confidence=confidence, source=source
        )

        if symbol not in self._buffer:
            self._buffer[symbol] = []
            # بدء مؤقت النافذة عند وصول أول إشارة
            loop = asyncio.get_event_loop()
            self._timers[symbol] = loop.call_later(
                settings.SIGNAL_WINDOW, self._schedule_flush, symbol
            )
            logger.debug(
                f"نافذة جديدة لـ {symbol} | مدة={settings.SIGNAL_WINDOW}s"
            )

        self._buffer[symbol].append(entry)
        logger.info(
            f"إشارة مستلمة: {action} {symbol} ثقة={confidence:.2f} من={source} "
            f"| عدد_في_النافذة={len(self._buffer[symbol])}"
        )

    def _schedule_flush(self, symbol: str) -> None:
        loop = asyncio.get_event_loop()
        loop.create_task(self._flush(symbol))

    async def _flush(self, symbol: str) -> None:
        entries = self._buffer.pop(symbol, [])
        self._timers.pop(symbol, None)

        if not entries:
            return

        score = self._compute_score(entries)
        logger.info(
            f"النافذة انتهت لـ {symbol} | إشارات={len(entries)} score={score:.3f}"
        )
        await self._on_flush(symbol, score, len(entries))

    def _compute_score(self, entries: list[SignalEntry]) -> float:
        """
        BUY  → +confidence
        SELL → -confidence
        مثال: [BUY 0.8، SELL 0.3] → 0.8 - 0.3 = 0.5
        """
        score = 0.0
        for e in entries:
            if e.action == "BUY":
                score += e.confidence
            else:
                score -= e.confidence
        return score
