import logging
from datetime import datetime, timedelta
from typing import List, Dict, Any
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.trade import Trade, TradeStatus, TradeResult
from app.analytics.metrics import (
    calculate_sharpe_ratio,
    calculate_max_drawdown,
    calculate_profit_factor,
    calculate_expectancy,
    calculate_rr_ratio,
)

# Module-level logger. NOTE(review): not referenced by PerformanceEngine as
# written; keep or remove once callers elsewhere are checked.
logger = logging.getLogger(__name__)


class PerformanceEngine:
    """Aggregates trading statistics from closed trades.

    Every public method loads all trades whose status is
    ``TradeStatus.CLOSED`` and computes its statistics in Python. PnL values
    come from ``Trade.pnl_percentage``. Trades whose PnL is ``None`` still
    count toward trade totals but are excluded from PnL sums.
    """

    async def get_global_stats(self, db: AsyncSession) -> Dict[str, Any]:
        """Return overall bot performance metrics.

        Args:
            db: Async SQLAlchemy session used to load closed trades.

        Returns:
            A dict containing:

            * trade counts and win rate (percent, 2 decimals);
            * total PnL;
            * max drawdown of the cumulative-PnL curve;
            * average risk/reward ratio;
            * Sharpe ratio, profit factor and expectancy.

            When there are no closed trades, every metric is zero.
        """
        trades = await self._fetch_closed_trades(db)
        if not trades:
            return self._empty_stats()

        total = len(trades)
        wins = sum(1 for t in trades if t.result == TradeResult.WIN)
        losses = sum(1 for t in trades if t.result == TradeResult.LOSS)
        win_rate = wins / total  # total > 0 is guaranteed by the guard above

        # PnL values in DB order, skipping trades with no recorded PnL.
        pnls = [t.pnl_percentage for t in trades if t.pnl_percentage is not None]
        total_pnl = sum(pnls)

        winning_pnls = [p for p in pnls if p > 0]
        losing_pnls = [p for p in pnls if p < 0]
        gross_profit = sum(winning_pnls)
        gross_loss = abs(sum(losing_pnls))
        profit_factor = calculate_profit_factor(gross_profit, gross_loss)

        # Average over the same PnL values that were summed, so numerator and
        # denominator always describe the same set of trades.
        avg_win = gross_profit / len(winning_pnls) if winning_pnls else 0.0
        avg_loss = gross_loss / len(losing_pnls) if losing_pnls else 0.0
        expectancy = calculate_expectancy(win_rate, avg_win, avg_loss)

        # Simple equity curve: cumulative PnL ordered by close time. Trades
        # without ``closed_at`` share one fallback timestamp, so they sort last.
        now = datetime.utcnow()
        equity_curve = []
        running = 0.0
        for t in sorted(trades, key=lambda x: x.closed_at or now):
            running += t.pnl_percentage or 0.0
            equity_curve.append(running)

        max_dd = calculate_max_drawdown(equity_curve) if equity_curve else 0.0
        # Zero-PnL trades are real observations; filter only missing values.
        sharpe = calculate_sharpe_ratio([p / 100 for p in pnls])

        # Average risk/reward, measured against the first take-profit level.
        rr_ratios = []
        for t in trades:
            if not t.take_profit_levels:
                continue
            rr = calculate_rr_ratio(t.entry_price, t.stop_loss, t.take_profit_levels[0])
            if rr > 0:
                rr_ratios.append(rr)
        avg_rr = sum(rr_ratios) / len(rr_ratios) if rr_ratios else 0.0

        return {
            "total_trades": total,
            "winning_trades": wins,
            "losing_trades": losses,
            "win_rate": round(win_rate * 100, 2),
            "total_pnl": round(total_pnl, 4),
            "max_drawdown": round(max_dd, 4),
            "avg_rr_ratio": round(avg_rr, 4),
            "sharpe_ratio": sharpe,
            "profit_factor": profit_factor,
            "expectancy": expectancy,
        }

    async def get_strategy_stats(self, db: AsyncSession) -> List[Dict[str, Any]]:
        """Return per-strategy performance, sorted by total PnL (descending).

        The strategy name is read from ``strategy_snapshot["strategy"]``.
        Trades without a snapshot or key are grouped under ``"unknown"``.
        """
        trades = await self._fetch_closed_trades(db)

        strategy_map: Dict[str, List[Trade]] = {}
        for t in trades:
            name = (t.strategy_snapshot or {}).get("strategy", "unknown")
            strategy_map.setdefault(name, []).append(t)

        stats = [
            {"strategy": name, **self._summarize(group)}
            for name, group in strategy_map.items()
        ]
        return sorted(stats, key=lambda x: x["total_pnl"], reverse=True)

    async def get_pair_stats(self, db: AsyncSession) -> List[Dict[str, Any]]:
        """Return per-pair performance, sorted by total PnL (descending)."""
        trades = await self._fetch_closed_trades(db)

        pair_map: Dict[str, List[Trade]] = {}
        for t in trades:
            pair_map.setdefault(t.pair, []).append(t)

        stats = [
            {"pair": pair, **self._summarize(group)}
            for pair, group in pair_map.items()
        ]
        return sorted(stats, key=lambda x: x["total_pnl"], reverse=True)

    async def get_time_stats(
        self, db: AsyncSession, period: str = "DAILY"
    ) -> List[Dict[str, Any]]:
        """
        Return performance bucketed by time period, in ascending key order.

        period: DAILY | WEEKLY | MONTHLY (case-insensitive). Any other value
        falls back to MONTHLY buckets. Weekly keys use ISO year/week
        (e.g. ``2024-W05``). Trades without ``closed_at`` are bucketed at
        the current (naive UTC) time.
        """
        trades = await self._fetch_closed_trades(db)
        mode = period.upper()
        now = datetime.utcnow()

        def bucket_key(trade: Trade) -> str:
            dt = trade.closed_at or now
            if mode == "DAILY":
                return dt.strftime("%Y-%m-%d")
            if mode == "WEEKLY":
                iso_year, iso_week, _ = dt.isocalendar()
                return f"{iso_year}-W{iso_week:02d}"
            return dt.strftime("%Y-%m")  # MONTHLY, and fallback for unknown periods

        bucket_map: Dict[str, List[Trade]] = {}
        for t in trades:
            bucket_map.setdefault(bucket_key(t), []).append(t)

        return [
            {"period": period_key, **self._summarize(group)}
            for period_key, group in sorted(bucket_map.items())
        ]

    @staticmethod
    async def _fetch_closed_trades(db: AsyncSession) -> List[Trade]:
        """Load every trade whose status is ``TradeStatus.CLOSED``."""
        result = await db.execute(
            select(Trade).where(Trade.status == TradeStatus.CLOSED)
        )
        return list(result.scalars().all())

    @staticmethod
    def _summarize(trades: List[Trade]) -> Dict[str, Any]:
        """Return count, wins, win rate (%) and total PnL for one group of trades."""
        total = len(trades)
        wins = sum(1 for t in trades if t.result == TradeResult.WIN)
        pnl = sum(t.pnl_percentage for t in trades if t.pnl_percentage is not None)
        return {
            "total_trades": total,
            "winning_trades": wins,
            "win_rate": round(wins / total * 100, 2) if total > 0 else 0.0,
            "total_pnl": round(pnl, 4),
        }

    def _empty_stats(self) -> Dict[str, Any]:
        """Return the global-stats dict with every metric zeroed."""
        return {
            "total_trades": 0,
            "winning_trades": 0,
            "losing_trades": 0,
            "win_rate": 0.0,
            "total_pnl": 0.0,
            "max_drawdown": 0.0,
            "avg_rr_ratio": 0.0,
            "sharpe_ratio": 0.0,
            "profit_factor": 0.0,
            "expectancy": 0.0,
        }


# Module-level instance of PerformanceEngine. The class defines no __init__
# and stores no instance attributes, so a single instance can be reused
# (presumably imported by callers elsewhere — verify).
performance_engine = PerformanceEngine()
