import copy
import logging
from typing import Dict

import pandas as pd
import pandas_ta as ta

from app.strategies.base_strategy import BaseStrategy, StrategyResult

logger = logging.getLogger(__name__)

# Default tuning for EMACrossoverStrategy. Keys passed through its ``params``
# argument take precedence over these values.
DEFAULT_PARAMS: dict[str, object] = {
    "fast_period": 9,  # length of the fast EMA
    "slow_period": 21,  # length of the slow EMA
    # First entry is the primary timeframe; an optional second entry confirms it.
    "timeframes": ["15m", "1h"],
}


class EMACrossoverStrategy(BaseStrategy):
    """
    EMA Crossover strategy with multi-timeframe confirmation.

    - BUY  (Golden Cross): fast EMA crosses above slow EMA.
    - SELL (Death Cross):  fast EMA crosses below slow EMA.
    - Without a fresh cross, a weaker BUY/SELL (confidence capped at 60) follows
      the side of the slow EMA that the fast EMA currently sits on.
    - Confidence scales with the percentage divergence between the two EMAs.

    ``params["timeframes"]`` lists the primary timeframe first and an optional
    confirmation timeframe second. A confirmation signal in a different
    direction vetoes the trade (HOLD, 0 confidence); an agreeing directional
    signal adds a fixed confidence boost.
    """

    # Confidence added when the confirmation timeframe agrees with a
    # directional primary signal.
    _CONFIRM_BOOST = 15.0

    def __init__(self, params: dict = None):
        # Deep-copy the defaults: a shallow merge would hand every instance the
        # same module-level "timeframes" list, so mutating one instance's params
        # would silently change the defaults for all later instances.
        merged = {**copy.deepcopy(DEFAULT_PARAMS), **(params or {})}
        super().__init__(name="EMA_Crossover", params=merged)

    def _compute_emas(self, df: pd.DataFrame) -> tuple[float, float, float, float]:
        """
        Return (fast_prev, fast_curr, slow_prev, slow_curr) EMA values.

        The two EMAs are aligned row-by-row and the last two rows where BOTH
        are defined are used, so "prev"/"curr" refer to the same bars for the
        fast and slow series. Returns all zeros (treated as "no data" by
        ``_crossover_signal``) when ``df`` has no ``close`` column, an EMA
        cannot be computed, or fewer than two aligned rows exist.
        """
        no_data = (0.0, 0.0, 0.0, 0.0)
        if "close" not in df.columns:
            logger.warning("EMACrossover: dataframe has no 'close' column")
            return no_data

        fast_ema = ta.ema(df["close"], length=self.params["fast_period"])
        slow_ema = ta.ema(df["close"], length=self.params["slow_period"])

        # ta.ema can return None instead of a Series; treat that as no data.
        if fast_ema is None or slow_ema is None:
            return no_data

        # Both EMAs are computed from the same column, so they have the same
        # length; combine them positionally and drop rows where either is NaN.
        emas = pd.DataFrame(
            {"fast": fast_ema.to_numpy(), "slow": slow_ema.to_numpy()}
        ).dropna()
        if len(emas) < 2:
            return no_data

        prev, curr = emas.iloc[-2], emas.iloc[-1]
        return (
            float(prev["fast"]),
            float(curr["fast"]),
            float(prev["slow"]),
            float(curr["slow"]),
        )

    def _crossover_signal(
        self, fast_prev: float, fast_curr: float, slow_prev: float, slow_curr: float
    ) -> tuple[str, float]:
        """
        Classify the latest EMA pair and score it.

        Returns ``(signal, confidence)`` with signal "BUY", "SELL" or "HOLD".
        A fresh cross between the previous and current bar scores
        ``min(100, divergence% * 20)`` (5% divergence -> 100). Without a cross,
        the current ordering of the EMAs yields a weaker
        ``min(60, divergence% * 10)``. A zero current EMA is treated as
        "no data" and yields HOLD.
        """
        if fast_curr == 0 or slow_curr == 0:
            return "HOLD", 0.0

        # Percentage gap between the EMAs, relative to the slow EMA.
        divergence = abs(fast_curr - slow_curr) / slow_curr * 100

        # Golden cross: fast crossed above slow
        if fast_prev <= slow_prev and fast_curr > slow_curr:
            return "BUY", min(100.0, divergence * 20)

        # Death cross: fast crossed below slow
        if fast_prev >= slow_prev and fast_curr < slow_curr:
            return "SELL", min(100.0, divergence * 20)

        # No crossover: weaker trend-continuation signal from the current ordering.
        if fast_curr > slow_curr:
            return "BUY", min(60.0, divergence * 10)
        if fast_curr < slow_curr:
            return "SELL", min(60.0, divergence * 10)

        return "HOLD", 0.0

    def _evaluate(self, df: pd.DataFrame) -> tuple[str, float, float, float]:
        """Return (signal, confidence, fast_curr, slow_curr) for one timeframe's frame."""
        fast_prev, fast_curr, slow_prev, slow_curr = self._compute_emas(df)
        signal, confidence = self._crossover_signal(
            fast_prev, fast_curr, slow_prev, slow_curr
        )
        return signal, confidence, fast_curr, slow_curr

    def analyze(self, data: Dict[str, pd.DataFrame]) -> StrategyResult:
        """
        Produce a signal from the primary timeframe, gated by the confirmation one.

        Args:
            data: Mapping of timeframe label (e.g. "15m") to a price DataFrame
                with a ``close`` column.

        Returns:
            StrategyResult with the signal, its confidence and the latest EMA
            values in ``metadata``. The signal is HOLD with 0 confidence when
            no timeframes are configured, the primary timeframe is missing
            from ``data``, or the confirmation timeframe signals a different
            direction.
        """
        timeframes = self.params["timeframes"]
        if not timeframes:
            logger.warning("EMACrossover: no timeframes configured")
            return StrategyResult(signal="HOLD", confidence=0.0)

        primary_tf = timeframes[0]
        confirm_tf = timeframes[1] if len(timeframes) > 1 else None

        if primary_tf not in data:
            logger.warning("EMACrossover: primary timeframe %s not in data", primary_tf)
            return StrategyResult(signal="HOLD", confidence=0.0)

        primary_signal, primary_conf, fc, sc = self._evaluate(data[primary_tf])

        confirm_signal = "HOLD"
        confirm_meta = {}
        if confirm_tf and confirm_tf in data:
            confirm_signal, _, cfc, csc = self._evaluate(data[confirm_tf])
            confirm_meta = {
                "confirm_fast_ema": round(cfc, 6),
                "confirm_slow_ema": round(csc, 6),
                "confirm_timeframe": confirm_tf,
            }

        # Require confirmation timeframe agreement
        if confirm_tf and confirm_signal not in (primary_signal, "HOLD"):
            final_signal = "HOLD"
            final_confidence = 0.0
        else:
            final_signal = primary_signal
            # Only a directional primary signal can be confirmed: a HOLD/HOLD
            # pair is not agreement and must not earn confidence.
            confirmed = primary_signal != "HOLD" and confirm_signal == primary_signal
            boost = self._CONFIRM_BOOST if confirmed else 0.0
            final_confidence = min(100.0, primary_conf + boost)

        return StrategyResult(
            signal=final_signal,
            confidence=final_confidence,
            metadata={
                "primary_fast_ema": round(fc, 6),
                "primary_slow_ema": round(sc, 6),
                "primary_timeframe": primary_tf,
                **confirm_meta,
            },
        )
