import logging
from datetime import datetime, timezone
from typing import List

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.data_layer.market_data import get_current_price
from app.models.trade import Trade, TradeStatus, TradeResult, TradeSignal

logger = logging.getLogger(__name__)


class TradeMonitor:
    """
    Monitors open trades; meant to be driven by a scheduler every 30 seconds.
    Checks stop-loss hits, take-profit hits (partial), and updates max drawdown.

    Optional async notification hooks (None until assigned, see __init__):
      * on_sl_hit(trade)         -- awaited after a trade is closed by its stop-loss
      * on_tp_hit(trade, index)  -- awaited for each newly hit take-profit level;
                                    ``index`` is 0-based
      * on_trade_closed(trade)   -- awaited after a trade is closed because all of
                                    its take-profit levels were hit
    """

    def __init__(self):
        # Will be set by main.py after telegram_service is initialised
        self.on_sl_hit = None   # async callback(trade)
        self.on_tp_hit = None   # async callback(trade, tp_index)
        self.on_trade_closed = None  # async callback(trade)

    async def check_open_trades(self, db: AsyncSession):
        """Evaluate every OPEN trade once. Called by the scheduler every 30 seconds.

        Trades are processed independently. If handling one trade raises
        (price lookup, DB commit, or a notification callback), the error is
        logged, the session is rolled back, and the remaining trades are still
        checked. One bad trade therefore cannot block stop-loss checks for the
        others.

        NOTE(review): each trade is committed individually, and the instances
        are reused afterwards. This presumably relies on the session being
        configured with expire_on_commit=False. Confirm against the session
        factory.

        Args:
            db: Session used to load and persist the trades.
        """
        result = await db.execute(
            select(Trade).where(Trade.status == TradeStatus.OPEN)
        )
        open_trades: List[Trade] = result.scalars().all()
        # Snapshot ids while the instances are freshly loaded, so the error
        # log below never needs to read an attribute of an expired instance.
        trade_ids = [trade.id for trade in open_trades]

        session_rolled_back = False
        for trade, trade_id in zip(open_trades, trade_ids):
            try:
                if session_rolled_back:
                    # A rollback expires every instance in the session. Reload
                    # explicitly (awaited) instead of letting attribute access
                    # trigger implicit I/O inside async code.
                    await db.refresh(trade)
                    if trade.status != TradeStatus.OPEN:
                        continue
                await self._process_trade(db, trade)
            except Exception:
                logger.exception(
                    "Failed to process trade %s; rolling back and continuing",
                    trade_id,
                )
                await db.rollback()
                session_rolled_back = True

    async def _process_trade(self, db: AsyncSession, trade: Trade):
        """Evaluate a single open trade against the current market price.

        Updates ``current_price`` and ``max_drawdown``, then either:
          * closes the trade as a LOSS if the stop-loss is hit (checked first), or
          * marks newly reached take-profit levels, and closes the trade as a
            WIN once every level has been hit.

        Changes are committed before returning. The exception is when no price
        is available; the trade is then skipped without modification.
        """
        current_price = get_current_price(trade.pair)
        if current_price is None:
            logger.warning(f"Could not fetch price for {trade.pair}, skipping trade {trade.id}")
            return

        trade.current_price = current_price

        # Record the largest adverse move seen so far, as a positive
        # percentage of the entry price.
        unrealised_pnl = self._pnl_percentage(trade, current_price)
        if unrealised_pnl < 0:
            drawdown = abs(unrealised_pnl)
            if trade.max_drawdown is None or drawdown > trade.max_drawdown:
                trade.max_drawdown = round(drawdown, 4)

        # Stop-loss wins over take-profit when both would trigger on this tick.
        if self._hit_stop_loss(trade, current_price):
            await self._close_trade(db, trade, current_price, TradeResult.LOSS)
            if self.on_sl_hit:
                await self.on_sl_hit(trade)
            return

        # Check take-profit levels (partial)
        tp_levels = list(trade.take_profit_levels or [])
        level_count = len(tp_levels)
        # Align the stored hit flags with the level list. Missing flags (no
        # tp_hits yet, or a list shorter than the levels) count as "not hit"
        # instead of raising IndexError. Surplus flags are dropped.
        tp_hits = (list(trade.tp_hits or []) + [False] * level_count)[:level_count]
        changed = False
        for i, tp_price in enumerate(tp_levels):
            if not tp_hits[i] and self._hit_take_profit(trade, current_price, tp_price):
                tp_hits[i] = True
                changed = True
                logger.info(f"Trade {trade.id}: TP{i+1} hit at {current_price}")
                # The callback runs before tp_hits is persisted. If anything
                # later in this method fails, the level is detected (and
                # notified) again on the next run.
                if self.on_tp_hit:
                    await self.on_tp_hit(trade, i)

        if changed:
            # Reassign a new list rather than mutating in place.
            # NOTE(review): presumably required for ORM change tracking of
            # this column; confirm the column type.
            trade.tp_hits = tp_hits
            # Close trade when all TPs are hit
            if all(tp_hits):
                await self._close_trade(db, trade, current_price, TradeResult.WIN)
                if self.on_trade_closed:
                    await self.on_trade_closed(trade)
                return

        await db.commit()

    @staticmethod
    def _pnl_percentage(trade: Trade, price: float) -> float:
        """Return the signed PnL of ``trade`` at ``price``, as a % of entry price.

        A positive value means the price moved in the trade's favour: up for a
        BUY, down for any other signal.
        """
        if trade.signal == TradeSignal.BUY:
            return (price - trade.entry_price) / trade.entry_price * 100
        return (trade.entry_price - price) / trade.entry_price * 100

    def _hit_stop_loss(self, trade: Trade, current_price: float) -> bool:
        """Return True if ``current_price`` has reached the trade's stop-loss.

        For a BUY, the stop is hit at or below ``stop_loss``. For any other
        signal, it is hit at or above it.
        """
        if trade.signal == TradeSignal.BUY:
            return current_price <= trade.stop_loss
        else:
            return current_price >= trade.stop_loss

    def _hit_take_profit(self, trade: Trade, current_price: float, tp_price: float) -> bool:
        """Return True if ``current_price`` has reached the take-profit ``tp_price``.

        For a BUY, the target is hit at or above ``tp_price``. For any other
        signal, it is hit at or below it.
        """
        if trade.signal == TradeSignal.BUY:
            return current_price >= tp_price
        else:
            return current_price <= tp_price

    async def _close_trade(
        self,
        db: AsyncSession,
        trade: Trade,
        exit_price: float,
        result: TradeResult,
    ):
        """Mark a trade as closed, compute its final PnL, and commit.

        Args:
            db: Session the trade belongs to; committed by this method.
            trade: The trade to close; mutated in place.
            exit_price: Price at which the trade is considered exited.
            result: Outcome to record (e.g. WIN / LOSS).
        """
        trade.status = TradeStatus.CLOSED
        # Naive UTC timestamp. This is the same value datetime.utcnow()
        # returned, but without that call's deprecation (Python 3.12+).
        trade.closed_at = datetime.now(timezone.utc).replace(tzinfo=None)
        trade.result = result
        trade.current_price = exit_price

        pnl_pct = self._pnl_percentage(trade, exit_price)

        trade.pnl_percentage = round(pnl_pct, 4)
        if trade.quantity:
            # pnl_pct / 100 * entry_price is the signed per-unit price move.
            trade.pnl_amount = round(pnl_pct / 100 * trade.entry_price * trade.quantity, 8)

        await db.commit()
        logger.info(
            f"Trade {trade.id} closed: {result} | PnL={pnl_pct:.2f}% | exit={exit_price}"
        )

# Shared module-level monitor instance. Its callback hooks (on_sl_hit,
# on_tp_hit, on_trade_closed) start as None and are expected to be assigned
# by main.py once the Telegram service is initialised.
trade_monitor = TradeMonitor()
