import logging
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy.ext.asyncio import AsyncSession

from app.data_layer.market_data import get_current_price
from app.models.trade import Trade, TradeMode, TradeSignal, TradeStatus
from app.config import settings

logger = logging.getLogger(__name__)


class SimulationEngine:
    """
    Same interface as ExecutionEngine but records simulated trades only.
    Applies optional slippage to simulate realistic fill prices.

    Slippage is a percentage of the live price and is always applied against
    the trader: a ``"BUY"`` fill is pushed up, any other signal is pushed down.
    """

    def __init__(self, slippage_pct: Optional[float] = None):
        """
        Args:
            slippage_pct: Slippage as a percentage of price (``0.1`` means
                0.1 %). When ``None``, ``settings.SLIPPAGE_PCT`` is used.
        """
        self.slippage_pct = slippage_pct if slippage_pct is not None else settings.SLIPPAGE_PCT

    async def open_trade(
        self,
        db: AsyncSession,
        pair: str,
        signal: str,
        entry_price: float,
        stop_loss: float,
        take_profit_levels: list,
        quantity: float,
        strategy_snapshot: dict,
    ) -> Optional[Trade]:
        """Record a simulated trade at the current market price with slippage.

        Args:
            db: Session used to persist the trade; it is committed here.
            pair: Market symbol passed to ``get_current_price``.
            signal: Trade direction; must be a valid ``TradeSignal`` value.
                ``"BUY"`` gets an upward price adjustment, any other value a
                downward one.
            entry_price: Fallback price used when ``get_current_price``
                returns a falsy value (e.g. ``None`` or ``0``).
            stop_loss: Stored unchanged on the trade.
            take_profit_levels: Stored unchanged; one ``tp_hits`` flag
                (initially ``False``) is created per level.
            quantity: Stored unchanged on the trade.
            strategy_snapshot: Stored unchanged on the trade.

        Returns:
            The persisted, refreshed ``Trade`` (never ``None`` in practice).

        Raises:
            ValueError: If ``signal`` is not a valid ``TradeSignal`` value.
            Exception: Any error raised by ``db.commit()`` is re-raised after
                the session has been rolled back.
        """
        # Validate the direction up front so an invalid signal fails before
        # any market-data lookup or DB work (same ValueError as before).
        trade_signal = TradeSignal(signal)

        # `or` means a falsy feed price (None/0) falls back to entry_price.
        live_price = get_current_price(pair) or entry_price

        # Apply slippage: unfavourable fill
        slippage = live_price * (self.slippage_pct / 100)
        if signal == "BUY":
            fill_price = live_price + slippage  # worse fill for buyer
        else:
            fill_price = live_price - slippage  # worse fill for seller

        trade = Trade(
            pair=pair,
            mode=TradeMode.SIMULATION,
            signal=trade_signal,
            entry_price=round(fill_price, 8),
            current_price=round(fill_price, 8),
            stop_loss=stop_loss,
            take_profit_levels=take_profit_levels,
            tp_hits=[False] * len(take_profit_levels),
            quantity=quantity,
            # Naive UTC timestamp, identical in value to the deprecated
            # datetime.utcnow() previously used here.
            opened_at=datetime.now(timezone.utc).replace(tzinfo=None),
            status=TradeStatus.OPEN,
            strategy_snapshot=strategy_snapshot,
        )

        db.add(trade)
        try:
            await db.commit()
        except Exception:
            # A failed commit leaves the session in a "pending rollback"
            # state; roll back so the caller's session stays usable.
            await db.rollback()
            raise
        await db.refresh(trade)
        logger.info(
            "SIMULATION trade opened: %s %s @ %.8f (slippage=%.8f, id=%s)",
            pair,
            signal,
            fill_price,
            slippage,
            trade.id,
        )
        return trade


# Shared module-level engine; an explicit None selects settings.SLIPPAGE_PCT.
simulation_engine = SimulationEngine(slippage_pct=None)
