from __future__ import annotations

import numpy as np
import pandas as pd
import pytest

from indicators.rsi import calculate_rsi


def make_close(n: int = 100, seed: int = 42) -> pd.Series:
    rng = np.random.default_rng(seed)
    prices = 60000 + np.cumsum(rng.normal(0, 100, n))
    return pd.Series(prices, dtype=float)


def test_rsi_length_matches_input():
    close = make_close(100)
    rsi = calculate_rsi(close, 14)
    assert len(rsi) == len(close)


def test_rsi_range_0_to_100():
    close = make_close(200)
    rsi = calculate_rsi(close, 14)
    valid = rsi.dropna()
    assert (valid >= 0.0).all()
    assert (valid <= 100.0).all()


def test_rsi_constant_prices_is_nan_or_50():
    """أسعار ثابتة → delta=0 → gains=losses=0 → RSI يجب أن يكون NaN أو 50."""
    close = pd.Series([100.0] * 50)
    rsi = calculate_rsi(close, 14)
    valid = rsi.dropna()
    # إما كل القيم NaN أو كلها 50 (يعتمد على المكتبة)
    if len(valid) > 0:
        assert all(abs(v - 50.0) < 1.0 or np.isnan(v) for v in valid)


def test_rising_prices_give_high_rsi():
    """أسعار صاعدة باستمرار → RSI قريب من 100."""
    close = pd.Series(range(1, 101), dtype=float)
    rsi = calculate_rsi(close, 14)
    assert float(rsi.dropna().iloc[-1]) > 80.0


def test_falling_prices_give_low_rsi():
    """أسعار هابطة باستمرار → RSI قريب من 0."""
    close = pd.Series(range(100, 0, -1), dtype=float)
    rsi = calculate_rsi(close, 14)
    assert float(rsi.dropna().iloc[-1]) < 20.0
