import pandas as pd
import ta as ta_lib
import logging
from typing import List, Optional
from app.config import settings

logger = logging.getLogger(__name__)


class RiskEngine:
    """
    Calculates stop-loss, take-profit levels, and position size.

    Stop-loss is ATR-based (volatility-adjusted).
    Take-profit levels: TP1 = 1.5R, TP2 = 3R.
    Position size is determined by fixed % risk per trade.

    Any constructor argument left as ``None`` is taken from ``settings``;
    an explicit value (including ``0``) is always honoured.
    """

    # Fraction of the latest close used as a stand-in ATR when a real ATR
    # cannot be computed (1% of current price).
    FALLBACK_ATR_FRACTION = 0.01

    # Take-profit targets expressed as multiples of R = |entry - stop_loss|.
    TP_R_MULTIPLES = (1.5, 3.0)

    def __init__(
        self,
        atr_multiplier: Optional[float] = None,
        risk_pct: Optional[float] = None,
        max_open_trades: Optional[int] = None,
    ):
        # Compare against None rather than relying on truthiness: with
        # `value or default`, an explicit 0 (e.g. max_open_trades=0 to block
        # all new trades) would silently be replaced by the configured value.
        self.atr_multiplier = (
            settings.ATR_MULTIPLIER_SL if atr_multiplier is None else atr_multiplier
        )
        self.risk_pct = (
            settings.RISK_PER_TRADE_PCT if risk_pct is None else risk_pct
        )
        self.max_open_trades = (
            settings.MAX_OPEN_TRADES if max_open_trades is None else max_open_trades
        )

    @classmethod
    def _fallback_atr(cls, df: pd.DataFrame) -> float:
        """Return the stand-in ATR: a fixed fraction of the latest close."""
        return float(df["close"].iloc[-1]) * cls.FALLBACK_ATR_FRACTION

    def calculate_atr(self, df: pd.DataFrame, period: int = 14) -> float:
        """
        Return the latest ATR value from a DataFrame.

        ``df`` must provide ``high``, ``low`` and ``close`` columns. When an
        ATR cannot be computed (fewer than ``period`` rows, or no non-NaN ATR
        value), 1% of the latest close is returned instead.

        Raises IndexError if ``df`` has no rows, because there is no latest
        close to fall back on.
        """
        # An ATR over `period` bars needs at least `period` rows. Guard here
        # instead of relying on the indicator: with too few rows it does not
        # yield a usable value (NOTE(review): the `ta` implementation appears
        # to raise IndexError in that case - confirm for the pinned version).
        if len(df) < period:
            return self._fallback_atr(df)

        atr_series = ta_lib.volatility.AverageTrueRange(
            df["high"], df["low"], df["close"], window=period
        ).average_true_range()
        if atr_series is None:
            return self._fallback_atr(df)

        valid_atr = atr_series.dropna()
        if valid_atr.empty:
            return self._fallback_atr(df)
        return float(valid_atr.iloc[-1])

    def calculate_stop_loss(
        self, entry_price: float, signal: str, atr: float
    ) -> float:
        """
        ATR-based stop-loss, rounded to 8 decimal places.
        BUY:  SL = entry - (ATR * multiplier)
        SELL: SL = entry + (ATR * multiplier)

        Any ``signal`` other than the exact string "BUY" is handled as SELL.
        """
        offset = atr * self.atr_multiplier
        if signal == "BUY":
            return round(entry_price - offset, 8)
        return round(entry_price + offset, 8)

    def calculate_take_profits(
        self, entry_price: float, stop_loss: float, signal: str
    ) -> List[float]:
        """
        TP1 = 1.5R, TP2 = 3R where R = |entry - stop_loss|.

        Targets sit above the entry for "BUY" and below it for any other
        signal. Returns ``[tp1, tp2]``, each rounded to 8 decimal places.
        """
        risk = abs(entry_price - stop_loss)
        # +1 moves targets above the entry (long), -1 below it (short).
        direction = 1 if signal == "BUY" else -1
        return [
            round(entry_price + direction * risk * multiple, 8)
            for multiple in self.TP_R_MULTIPLES
        ]

    def calculate_position_size(
        self,
        capital: float,
        entry_price: float,
        stop_loss: float,
    ) -> float:
        """
        Position size based on fixed risk % of capital.
        quantity = (capital * risk_pct / 100) / |entry - stop_loss|

        Returns 0.0 when entry and stop-loss coincide (no defined risk per
        unit); otherwise the quantity rounded to 6 decimal places.
        """
        price_risk = abs(entry_price - stop_loss)
        if price_risk == 0:
            # Avoid dividing by zero: without a stop distance no size is defined.
            return 0.0
        risk_amount = capital * self.risk_pct / 100
        return round(risk_amount / price_risk, 6)

    def can_open_trade(self, open_trade_count: int) -> bool:
        """Check if another trade can be opened given current open trade count."""
        return open_trade_count < self.max_open_trades


# Shared module-level instance, created at import time. No arguments are
# passed, so every parameter takes its default from `settings`.
risk_engine = RiskEngine()
