from datetime import datetime, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.data_layer.market_data import get_current_price
from app.engines.trade_monitor import trade_monitor
from app.models.trade import Trade, TradeStatus, TradeMode, TradeResult

# Router for every trade endpoint in this module; all paths below are
# relative to the /api/v1/trades prefix and grouped under the "trades" tag.
router = APIRouter(tags=["trades"], prefix="/api/v1/trades")


def _parse_enum(enum_cls, raw: str, field: str):
    """Convert a case-insensitive query value into a member of *enum_cls*.

    The value is upper-cased and looked up by enum value. An unknown value
    raises HTTPException(422) listing the accepted values. Without this, the
    enum's ValueError would escape as a 500 Internal Server Error.
    """
    try:
        return enum_cls(raw.upper())
    except ValueError:
        allowed = ", ".join(str(member.value) for member in enum_cls)
        raise HTTPException(
            status_code=422,
            detail=f"Invalid {field} '{raw}'. Expected one of: {allowed}",
        ) from None


@router.get("")
async def list_trades(
    status: Optional[str] = Query(None, description="OPEN or CLOSED"),
    mode: Optional[str] = Query(None, description="REAL or SIMULATION"),
    pair: Optional[str] = Query(None),
    # ge=1: a negative LIMIT would otherwise reach the database, where e.g.
    # SQLite treats LIMIT -1 as "no limit" and bypasses the le=500 cap.
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List trades with optional filters, newest first.

    Args:
        status: Case-insensitive TradeStatus value to filter on.
        mode: Case-insensitive TradeMode value to filter on.
        pair: Trading pair to filter on (compared upper-cased).
        limit: Maximum number of trades to return (1-500).
        offset: Number of trades to skip, for pagination.

    Returns:
        ``{"trades": [...], "count": n}``. ``count`` is the number of trades
        in this page, not the total number of matching trades.

    Raises:
        HTTPException: 422 if ``status`` or ``mode`` is not a valid value.
    """
    stmt = select(Trade)
    if status:
        stmt = stmt.where(Trade.status == _parse_enum(TradeStatus, status, "status"))
    if mode:
        stmt = stmt.where(Trade.mode == _parse_enum(TradeMode, mode, "mode"))
    if pair:
        stmt = stmt.where(Trade.pair == pair.upper())
    # Tie-break on id so offset pagination is deterministic when several
    # trades share the same opened_at timestamp.
    stmt = (
        stmt.order_by(Trade.opened_at.desc(), Trade.id.desc())
        .offset(offset)
        .limit(limit)
    )

    result = await db.execute(stmt)
    trades = result.scalars().all()
    return {"trades": [t.to_dict() for t in trades], "count": len(trades)}


@router.get("/{trade_id}")
async def get_trade(trade_id: int, db: AsyncSession = Depends(get_db)):
    """Return the trade identified by ``trade_id`` as a dict.

    Raises:
        HTTPException: 404 when no trade with that id exists.
    """
    query = select(Trade).where(Trade.id == trade_id)
    found = (await db.execute(query)).scalar_one_or_none()
    if found is not None:
        return found.to_dict()
    raise HTTPException(status_code=404, detail="Trade not found")


@router.post("/{trade_id}/close")
async def close_trade_manually(trade_id: int, db: AsyncSession = Depends(get_db)):
    """Manually close an open trade at the current market price.

    Looks up the trade and fetches the latest price for its pair. It then
    computes the PnL relative to ``entry_price`` and marks the trade CLOSED
    with a WIN / LOSS / BREAKEVEN result.

    Returns:
        ``{"message": "Trade closed", "trade": <updated trade dict>}``.

    Raises:
        HTTPException: 404 if the trade does not exist, 400 if it is not
            OPEN, 503 if no current price could be obtained.
    """
    result = await db.execute(select(Trade).where(Trade.id == trade_id))
    trade = result.scalar_one_or_none()
    if trade is None:
        raise HTTPException(status_code=404, detail="Trade not found")
    if trade.status != TradeStatus.OPEN:
        raise HTTPException(status_code=400, detail="Trade is not open")

    # NOTE(review): get_current_price is called synchronously from an async
    # handler. If it performs network I/O it blocks the event loop, and
    # fastapi.concurrency.run_in_threadpool would be the remedy. Confirm
    # against its implementation.
    current_price = get_current_price(trade.pair)
    if not current_price:
        raise HTTPException(status_code=503, detail="Could not fetch current price")

    # Signed price move in the trade's favour. BUY profits when the price
    # rises; any other signal is treated as a short that profits when it falls.
    # NOTE(review): assumes trade.signal compares equal to the string "BUY".
    # If it is a non-str Enum, every trade would be treated as a short.
    if trade.signal == "BUY":
        price_move = current_price - trade.entry_price
    else:
        price_move = trade.entry_price - current_price
    pnl_pct = price_move / trade.entry_price * 100

    # Classify the outcome from the unrounded percentage.
    if pnl_pct > 0:
        trade_result = TradeResult.WIN
    elif pnl_pct < 0:
        trade_result = TradeResult.LOSS
    else:
        trade_result = TradeResult.BREAKEVEN

    trade.status = TradeStatus.CLOSED
    # Naive UTC timestamp with the same value as the deprecated
    # datetime.utcnow(). It stays naive so existing DateTime columns still
    # accept it.
    trade.closed_at = datetime.now(timezone.utc).replace(tzinfo=None)
    trade.result = trade_result
    trade.current_price = current_price
    trade.pnl_percentage = round(pnl_pct, 4)
    if trade.quantity:
        # Absolute PnL taken directly from the price move. This avoids a
        # percent -> fraction -> price round-trip that accumulates float error.
        trade.pnl_amount = round(price_move * trade.quantity, 8)

    await db.commit()
    await db.refresh(trade)
    return {"message": "Trade closed", "trade": trade.to_dict()}
