import pandas as pd
import logging
from typing import List, Dict, Any
from app.strategies.base_strategy import BaseStrategy
from app.analytics.metrics import (
    calculate_sharpe_ratio,
    calculate_max_drawdown,
    calculate_profit_factor,
    calculate_expectancy,
)

# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


def run_backtest(
    strategy: BaseStrategy,
    data: Dict[str, pd.DataFrame],
    initial_capital: float = 1000.0,
    risk_pct: float = 1.0,
    sl_pct: float = 2.0,
    tp_pct: float = 4.0,
    warmup: int = 50,
    min_confidence: float = 50,
) -> Dict[str, Any]:
    """
    Simple event-driven backtest on historical OHLCV data.

    Walks the primary timeframe (the first key of ``data``) candle by candle.
    At candle ``i`` the strategy receives every timeframe truncated to its
    first ``i`` rows (so the current primary candle is excluded), and an
    entry signal is filled at the current candle's close. Stop-loss and
    take-profit are fixed percentages from the entry price. They are checked
    against the candle close only, and a triggered exit is booked at the
    stop/target level itself. At most one position is open at a time. A
    position still open when the data runs out is logged and discarded (not
    marked to market).

    Position sizing is risk-based: a full stop-loss hit changes capital by
    ``risk_pct`` % of current capital, and every outcome is scaled linearly
    by ``pnl_pct / sl_pct``.

    Args:
        strategy:        instance of BaseStrategy; ``analyze`` must return an
                         object exposing ``signal`` and ``confidence``
        data:            multi-timeframe OHLCV dict; the first entry is the
                         primary timeframe and must have a ``close`` column
        initial_capital: starting capital in quote currency (must be > 0)
        risk_pct:        % of capital to risk per trade
        sl_pct:          stop-loss percentage from entry (must be > 0)
        tp_pct:          take-profit percentage from entry (must be > 0)
        warmup:          number of leading primary candles skipped so that
                         indicators can initialise (must be >= 0)
        min_confidence:  minimum ``result.confidence`` needed to open a trade

    Returns:
        dict with equity curve, trade list, and summary metrics.

    Raises:
        ValueError: if ``data`` is empty or a numeric argument is out of range.
    """
    if not data:
        raise ValueError("data must contain at least one timeframe")
    if initial_capital <= 0:
        # total_return_pct divides by initial_capital.
        raise ValueError("initial_capital must be positive")
    if sl_pct <= 0:
        # sl_pct is the divisor in position sizing below; a non-positive value
        # would also place the stop on the wrong side of the entry.
        raise ValueError("sl_pct must be positive")
    if tp_pct <= 0:
        raise ValueError("tp_pct must be positive")
    if warmup < 0:
        raise ValueError("warmup must be non-negative")

    primary_tf = next(iter(data))
    # Read closes positionally from the original frame. This avoids copying
    # the whole frame and the reset_index() call, which raises if the index
    # name collides with an existing column.
    closes = data[primary_tf]["close"]

    capital = initial_capital
    equity_curve = [capital]
    trade_log: List[Dict[str, Any]] = []
    in_trade = False
    entry_price = 0.0
    trade_signal = ""
    sl_price = 0.0
    tp_price = 0.0

    for i in range(warmup, len(closes)):
        # The strategy sees rows [0, i) of every timeframe. It is called on
        # every candle, including while a position is open (its result is
        # then unused).
        # NOTE(review): all timeframes are truncated at the same *row count*
        # as the primary one. If a coarser timeframe is present, iloc[:i] can
        # expose candles that close after the current primary candle
        # (look-ahead bias). Aligning by timestamp would fix this if the
        # frames carry a time index -- TODO confirm the data layout.
        slice_data = {tf: frame.iloc[:i] for tf, frame in data.items()}
        result = strategy.analyze(slice_data)

        close = float(closes.iloc[i])

        if not in_trade:
            if (
                result.signal in ("BUY", "SELL")
                and result.confidence >= min_confidence
            ):
                in_trade = True
                entry_price = close
                trade_signal = result.signal

                if trade_signal == "BUY":
                    sl_price = entry_price * (1 - sl_pct / 100)
                    tp_price = entry_price * (1 + tp_pct / 100)
                else:
                    sl_price = entry_price * (1 + sl_pct / 100)
                    tp_price = entry_price * (1 - tp_pct / 100)
        else:
            # Exits are evaluated on the close only; intrabar highs/lows are
            # not consulted.
            hit_sl = (
                (trade_signal == "BUY" and close <= sl_price)
                or (trade_signal == "SELL" and close >= sl_price)
            )
            hit_tp = (
                (trade_signal == "BUY" and close >= tp_price)
                or (trade_signal == "SELL" and close <= tp_price)
            )

            if hit_sl or hit_tp:
                # The fill is booked at the stop/target level, not at the close.
                exit_price = sl_price if hit_sl else tp_price
                if trade_signal == "BUY":
                    pnl_pct = (exit_price - entry_price) / entry_price * 100
                else:
                    pnl_pct = (entry_price - exit_price) / entry_price * 100

                # Risk-based sizing: a full stop-loss move costs risk_pct % of
                # current capital; other outcomes scale by pnl_pct / sl_pct.
                pnl_amount = capital * (risk_pct / 100) * (pnl_pct / sl_pct)
                capital += pnl_amount
                equity_curve.append(capital)

                trade_log.append({
                    "entry": entry_price,
                    "exit": exit_price,
                    "signal": trade_signal,
                    "pnl_pct": round(pnl_pct, 4),
                    "result": "LOSS" if hit_sl else "WIN",
                })
                in_trade = False

    if in_trade:
        # Previously this position vanished silently; make it visible.
        logger.info(
            "Backtest ended with an open %s position (entry %s); "
            "it is not included in the results",
            trade_signal,
            entry_price,
        )

    total = len(trade_log)
    wins = sum(1 for t in trade_log if t["result"] == "WIN")
    losses = total - wins
    win_rate = wins / total if total > 0 else 0.0
    pnls = [t["pnl_pct"] for t in trade_log]
    gross_profit = sum(p for p in pnls if p > 0)
    gross_loss = abs(sum(p for p in pnls if p < 0))
    avg_win = gross_profit / wins if wins > 0 else 0.0
    avg_loss = gross_loss / losses if losses > 0 else 0.0

    return {
        "strategy": strategy.get_name(),
        "initial_capital": initial_capital,
        "final_capital": round(capital, 2),
        "total_return_pct": round(
            (capital - initial_capital) / initial_capital * 100, 4
        ),
        "total_trades": total,
        "winning_trades": wins,
        "losing_trades": losses,
        "win_rate": round(win_rate * 100, 2),
        "max_drawdown": calculate_max_drawdown(equity_curve),
        "sharpe_ratio": calculate_sharpe_ratio([p / 100 for p in pnls]),
        "profit_factor": calculate_profit_factor(gross_profit, gross_loss),
        "expectancy": calculate_expectancy(win_rate, avg_win, avg_loss),
        "equity_curve": [round(v, 2) for v in equity_curve],
        "trades": trade_log,
    }
