import math
from mathlib import add, subtract, multiply, divide, factorial, exp, pow, root_n, log, ln 
from errors import ZeroDivisionMathError, DomainError, ConvergenceError
import pytest 

#======================================== ADD =============================================

# (x, y, expected) triples for add: integer sign combinations, zero
# operands, large magnitudes, and floats of very small / very large scale.
_ADD_CASES = [
    (1, 1, 2), (-1, 1, 0), (5, 3, 8),
    (5, -3, 2), (-5, 3, -2), (-5, -3, -8),
    (0, 0, 0), (0, 7, 7), (7, 0, 7),
    (1000000, 999999, 1999999), (-1000000, -999999, -1999999),
    (0.1, 0.2, 0.3), (-0.1, 0.2, 0.1), (7.156, -0.2679, 6.8881),
    (-10.678, -12.00001, -22.67801), (1e-10, 2e-10, 3e-10), (1e10, 2e10, 3e10),
]


@pytest.mark.parametrize("x, y, expected", _ADD_CASES)
def test_add_correctness(x, y, expected):
    """add(x, y) equals the expected sum within a relative tolerance of 1e-5."""
    total = add(x, y)
    assert total == pytest.approx(expected, rel=1e-5)

@pytest.mark.parametrize("x, y", [(2, 3), (0, 5), (-4, 7), (0.5, 1.5)])
def test_add_commutative(x, y):
    """add is commutative: add(x, y) == add(y, x)."""
    assert add(x, y) == add(y, x)


# Identity checks previously lived inside test_add_commutative; they are a
# separate property, so they get their own test (and now cover both operands
# that the commutativity cases use).
@pytest.mark.parametrize("x", [2, 0, -4, 0.5, 3, 5, 7, 1.5])
def test_add_identity(x):
    """0 is the additive identity on either side: add(x, 0) == add(0, x) == x."""
    assert add(x, 0) == x
    assert add(0, x) == x

@pytest.mark.parametrize(
    "x, y, z",
    [
        (1, 2, 3),
        (0, 0, 0),
        (-1, 4, -2),
    ],
)
def test_add_associative(x, y, z):
    """Grouping does not matter: (x + y) + z == x + (y + z)."""
    left_grouped = add(add(x, y), z)
    right_grouped = add(x, add(y, z))
    assert left_grouped == right_grouped

@pytest.mark.parametrize(
    "k, x, y",
    [
        (2, 3, 4),
        (-5, 1, -2),
        (0, 10, 20),
    ],
)
def test_add_distributive_over_multiply(k, x, y):
    """Scaling distributes over add: k * (x + y) == k*x + k*y."""
    scaled_sum = k * add(x, y)
    sum_of_scaled = add(k * x, k * y)
    assert scaled_sum == sum_of_scaled

#======================================== SUBTRACT ===========================================

# (x, y, expected) triples for subtract: sign combinations, zero operands,
# large magnitudes, and floats of very small / very large scale.
_SUBTRACT_CASES = [
    (1, 1, 0), (-1, 1, -2), (5, 3, 2),
    (5, -3, 8), (-5, 3, -8), (-5, -3, -2),
    (0, 0, 0), (7, 0, 7), (0, 7, -7),
    (1000000, 1, 999999), (-1000000, -1, -999999),
    (0.1, 0.2, -0.1), (-0.1, 0.2, -0.3), (7.156, -0.2679, 7.4239),
    (-10.678, -12.00001, 1.32201), (1e-9, 2e-9, -1e-9), (1e12, 1e11, 9e11),
]


@pytest.mark.parametrize("x, y, expected", _SUBTRACT_CASES)
def test_subtract_correctness(x, y, expected):
    """subtract(x, y) equals the expected difference within rel=1e-5."""
    difference = subtract(x, y)
    assert difference == pytest.approx(expected, rel=1e-5)

@pytest.mark.parametrize(
    "x, y",
    [
        (5, 3),
        (0, 7),
        (-4, 2),
        (1.2, 3.4),
    ],
)
def test_subtract_not_commutative(x, y):
    """For distinct operands, swapping them changes the result of subtract."""
    forward = subtract(x, y)
    reversed_order = subtract(y, x)
    assert forward != reversed_order

@pytest.mark.parametrize(
    "x, y, z",
    [
        (10, 4, 2),
        (5, 0, 3),
    ],
)
def test_subtract_not_associative(x, y, z):
    """For these triples, (x - y) - z differs from x - (y - z)."""
    left_grouped = subtract(subtract(x, y), z)
    right_grouped = subtract(x, subtract(y, z))
    assert left_grouped != right_grouped

@pytest.mark.parametrize(
    "k, x, y",
    [
        (3, 8, 5),
        (-2, 10, -6),
    ],
)
def test_subtract_distributive_over_multiply(k, x, y):
    """Scaling distributes over subtract: k * (x - y) == k*x - k*y."""
    scaled_difference = k * subtract(x, y)
    difference_of_scaled = subtract(k * x, k * y)
    assert scaled_difference == difference_of_scaled

@pytest.mark.parametrize(
    "x",
    [0, 42, -7, 0.001, 1e-8, -1e5],
)
def test_subtract_self_is_zero(x):
    """x - x is 0 within an absolute tolerance of 1e-5."""
    residual = subtract(x, x)
    assert residual == pytest.approx(0, abs=1e-5)

#======================================== MULTIPLY ===========================================

# (x, y, expected) triples for multiply: sign combinations, zero factors,
# large integers, and floats of very small / very large scale.
_MULTIPLY_CASES = [
    (1, 1, 1), (-1, 1, -1), (5, 3, 15),
    (5, -3, -15), (-5, 3, -15), (-5, -3, 15),
    (0, 0, 0), (0, 42, 0), (777, 0, 0),
    (10000, 9999, 99990000), (-10000, -9999, 99990000),
    (0.1, 0.2, 0.02), (-0.1, 0.2, -0.02),
    (7.156, -0.2679, -1.9170924), (-10.678, -12.00001, 128.13610678),
    (1e-8, 1e-7, 1e-15), (1e6, 1e5, 1e11),
]


@pytest.mark.parametrize("x, y, expected", _MULTIPLY_CASES)
def test_multiply_correctness(x, y, expected):
    """multiply(x, y) equals the expected product within rel=1e-5."""
    product = multiply(x, y)
    assert product == pytest.approx(expected, rel=1e-5)

@pytest.mark.parametrize(
    "x, y",
    [
        (2, 3),
        (0, 5),
        (-4, 7),
        (0.5, 1.5),
    ],
)
def test_multiply_commutative(x, y):
    """multiply is commutative: y * x == x * y."""
    swapped = multiply(y, x)
    original_order = multiply(x, y)
    assert swapped == original_order

@pytest.mark.parametrize(
    "x, y, z",
    [
        (1, 2, 3),
        (3, 7, -1),
        (-1, 4, -2),
    ],
)
def test_multiply_associative(x, y, z):
    """Grouping does not matter: (x * y) * z == x * (y * z)."""
    left_grouped = multiply(multiply(x, y), z)
    right_grouped = multiply(x, multiply(y, z))
    assert left_grouped == right_grouped

@pytest.mark.parametrize(
    "k, x, y",
    [
        (2, 3, 4),
        (-5, 1, -2),
        (9, 10, 20),
    ],
)
def test_multiply_distributive_over_add(k, x, y):
    """multiply distributes over addition: k * (x + y) == k*x + k*y."""
    product_of_sum = multiply(k, x + y)
    sum_of_products = multiply(k, x) + multiply(k, y)
    assert product_of_sum == sum_of_products

@pytest.mark.parametrize("x", [2, 3, 4, -534, 101, -209, 9, 10, 20])
def test_multiply_zero_is_zero(x):
    """Multiplying by 0 on either side yields 0."""
    # A relative tolerance scales with the expected value and so is useless
    # when the expected value is 0; state the absolute tolerance explicitly.
    assert multiply(x, 0) == pytest.approx(0.0, abs=1e-12)
    assert multiply(0, x) == pytest.approx(0.0, abs=1e-12)

#======================================== DIVIDE ===========================================

# (x, y, expected) triples for divide: exact integer quotients, sign
# combinations, fractional results, and very small / very large operands.
_DIVIDE_CASES = [
    (1, 1, 1), (-1, 1, -1), (15, 3, 5),
    (30, -3, -10), (-24, 6, -4), (7, 2, 3.5),
    (10, -2, -5), (-10, -2, 5),
    (0.5, 2, 0.25), (1, 4, 0.25), (7.5, 2.5, 3.0),
    (0.0001, 0.00002, 5.0), (1e12, 1e6, 1e6), (1e-10, 2e-5, 5e-6),
]


@pytest.mark.parametrize("x, y, expected", _DIVIDE_CASES)
def test_divide_correctness(x, y, expected):
    """divide(x, y) equals the expected quotient within rel=1e-5."""
    quotient = divide(x, y)
    assert quotient == pytest.approx(expected, rel=1e-5)

def test_divide_zero():
    """Dividing by 0 raises ZeroDivisionMathError, whatever the numerator is."""
    for numerator in (1, -42, 0):
        with pytest.raises(ZeroDivisionMathError):
            divide(numerator, 0)

@pytest.mark.parametrize(
    "x, y",
    [
        (6, 2),
        (-8, 4),
        (1.5, 0.5),
        (10, -5),
    ],
)
def test_divide_not_commutative(x, y):
    """For these operands, x / y differs from y / x."""
    forward = divide(x, y)
    reversed_order = divide(y, x)
    assert forward != reversed_order

@pytest.mark.parametrize(
    "x",
    [8, 2, 1, -4.5, 1.5],
)
def test_divide_by_one_is_identity(x):
    """Dividing by 1 returns the numerator (default approx tolerance)."""
    quotient = divide(x, 1)
    assert quotient == pytest.approx(x)

@pytest.mark.parametrize(
    "x",
    [1, -1, 3.14, -2.718, 1e-5],
)
def test_divide_self_by_self_is_one(x):
    """Any nonzero x divided by itself is 1."""
    ratio = divide(x, x)
    assert ratio == pytest.approx(1.0, rel=1e-9, abs=1e-14)

@pytest.mark.parametrize(
    "x, y",
    [
        (10, 2),
        (100, 5),
        (-24, -3),
        (0.8, 0.4),
    ],
)
def test_divide_then_multiply(x, y):
    """Multiplying the quotient x / y back by y recovers x."""
    round_trip = divide(x, y) * y
    assert round_trip == pytest.approx(x, rel=1e-9, abs=1e-14)

#======================================== FACTORIAL ===========================================

# (n, n!) pairs: the base cases 0! and 1!, small values, and larger values
# cross-checked against the standard library.
_FACTORIAL_CASES = [
    (0, 1), (1, 1), (2, 2),
    (3, 6), (4, 24), (5, 120),
    (12, 479001600), (15, math.factorial(15)), (20, math.factorial(20)),
]


@pytest.mark.parametrize("x, expected", _FACTORIAL_CASES)
def test_factorial_correctness(x, expected):
    """factorial(x) matches the expected value within rel=1e-5."""
    value = factorial(x)
    assert value == pytest.approx(expected, rel=1e-5)

@pytest.mark.parametrize(
    "x",
    [-1, -2, -109, -1000, -466543],
)
def test_factorial_negative(x):
    """factorial rejects negative integers with DomainError."""
    expectation = pytest.raises(DomainError)
    with expectation:
        factorial(x)

@pytest.mark.parametrize(
    "x",
    [1.666666667, 2.5, 109.8, 1000.3253, 2.33333],
)
def test_factorial_float(x):
    """factorial rejects non-integer floats with TypeError."""
    expectation = pytest.raises(TypeError)
    with expectation:
        factorial(x)

#======================================== POW ===========================================

# (base, exponent, expected) triples for pow: non-negative integer exponents,
# negative integer exponents, and fractional exponents (roots).
_POW_CASES = [
    (2, 0, 1), (2, 2, 4), (2, 3, 8), (5, 1, 5),
    (1, -1, 1), (2, -1, 0.5), (2, -3, 0.125), (5, -4, 0.0016),
    (4, 0.5, 2.0), (9, 0.5, 3.0), (8, 1/3, 2.0), (16, 0.25, 2.0),
]


@pytest.mark.parametrize("x, y, expected", _POW_CASES)
def test_pow_correctness(x, y, expected):
    """pow(x, y) matches the expected power within rel=1e-5."""
    power = pow(x, y)
    assert power == pytest.approx(expected, rel=1e-5)

def test_pow_zero_base_zero_exponent():
    """mathlib treats 0 ** 0 as undefined and raises DomainError."""
    expectation = pytest.raises(DomainError)
    with expectation:
        pow(0, 0)

@pytest.mark.parametrize("exponent", [-8, -2, -1, -4.5, -1.5])
def test_pow_zero_base_negative_exponent(exponent):
    """0 raised to any negative exponent raises DomainError.

    The parameter is the exponent (the base is fixed at 0), so it is named
    accordingly rather than ``x``.
    """
    with pytest.raises(DomainError):
        pow(0, exponent)

@pytest.mark.parametrize(
    "x, y",
    [
        (-1, 1.666666667),
        (-3, 2.5),
        (-23, -109.8),
        (-10, 1000.3253),
        (-0.3, 2.33333),
    ],
)
def test_pow_negative_base_non_integer_exponent(x, y):
    """A negative base with a non-integer exponent raises DomainError."""
    expectation = pytest.raises(DomainError)
    with expectation:
        pow(x, y)

def test_pow_base_one_identity():
    """1 raised to every integer power in [-10, 10] is 1."""
    # Named ``exponent`` rather than ``exp`` so the loop variable does not
    # shadow the ``exp`` function imported from mathlib.
    for exponent in range(-10, 11):
        assert pow(1, exponent) == 1

def test_pow_base_zero_identity():
    """0 raised to every positive integer power in [1, 10] is 0."""
    # Named ``exponent`` rather than ``exp`` so the loop variable does not
    # shadow the ``exp`` function imported from mathlib.
    for exponent in range(1, 11):
        assert pow(0, exponent) == 0

@pytest.mark.parametrize(
    "k, x, y",
    [
        (2, 2, 1),
        (-3, 4, 1),
        (10, 10, -3),
        (1.7, 0.3, 1.1),
    ],
)
def test_pow_product_property(k, x, y):
    """Exponents add under multiplication: k**(x + y) == k**x * k**y."""
    combined = pow(k, x + y)
    split = pow(k, x) * pow(k, y)
    assert combined == pytest.approx(split, rel=1e-5)

@pytest.mark.parametrize(
    "k, x, y",
    [
        (2, 2, 1),
        (-3, 4, 1),
        (10, 10, -3),
        (1.7, 0.3, 1.1),
    ],
)
def test_pow_quotient_property(k, x, y):
    """Exponents subtract under division: k**(x - y) == k**x / k**y."""
    combined = pow(k, x - y)
    split = pow(k, x) / pow(k, y)
    assert combined == pytest.approx(split, rel=1e-5)

#======================================== EXPONENT ===========================================

# (x, e**x) pairs, with reference values taken from the math module.
_EXP_CASES = [
    (0, 1.0),
    (1, math.e),
    (-1, 1/math.e),
    (0.5, math.sqrt(math.e)),
    (2, math.e**2),
    (-3, math.pow(math.e, -3)),
    (-5/9, math.pow(math.e, -5/9)),
]


@pytest.mark.parametrize("x, expected", _EXP_CASES)
def test_exp_correctness(x, expected):
    """exp(x) matches the reference value within rel=1e-9."""
    value = exp(x)
    assert value == pytest.approx(expected, rel=1e-9)

@pytest.mark.parametrize(
    "x, y",
    [
        (2, 2),
        (-3, 4),
        (10, 5),
        (1.7, 0.3),
    ],
)
def test_exp_product_property(x, y):
    """exp turns sums into products: exp(x + y) == exp(x) * exp(y)."""
    combined = exp(x + y)
    split = exp(x) * exp(y)
    assert combined == pytest.approx(split, rel=1e-5)

@pytest.mark.parametrize(
    "x, y",
    [
        (2, 2),
        (-3, 4),
        (10, 5),
        (1.7, 0.3),
    ],
)
def test_exp_quotient_property(x, y):
    """exp turns differences into quotients: exp(x - y) == exp(x) / exp(y)."""
    combined = exp(x - y)
    split = exp(x) / exp(y)
    assert combined == pytest.approx(split, rel=1e-5)

def test_exp_convergence_error_possible():
    """exp raises ConvergenceError when eps is unreachable within max_iter."""
    # One iteration with a 1e-50 tolerance is expected to exhaust the budget.
    expectation = pytest.raises(ConvergenceError)
    with expectation:
        exp(2, eps=1e-50, max_iter=1)

#======================================== ROOT_N ===========================================

# (radicand, degree, expected root) triples: perfect powers, odd roots of
# negative numbers, and the trivial radicands 0 and 1.
_ROOT_N_CASES = [
    (16, 2, 4), (81, 4, 3), (27, 3, 3),
    (-8, 3, -2), (-32, 5, -2),
    (0, 7, 0), (1, 10, 1),
]


@pytest.mark.parametrize("x, y, expected", _ROOT_N_CASES)
def test_root_n_correctness(x, y, expected):
    """root_n(x, y) returns the y-th root of x within rel=1e-9."""
    root = root_n(x, y)
    assert root == pytest.approx(expected, rel=1e-9)

def test_root_n_negative_base_even_exponent():
    """An even root of a negative number raises DomainError."""
    expectation = pytest.raises(DomainError)
    with expectation:
        root_n(-4, 2)

def test_root_n_zero_exponent():
    """A 0th root is rejected with DomainError."""
    expectation = pytest.raises(DomainError)
    with expectation:
        root_n(4, 0)

@pytest.mark.parametrize(
    "x, y, expected",
    [
        (-8, 3, -2),
        (-27, 3, -3),
    ],
)
def test_root_n_negative_odd_root(x, y, expected):
    """An odd root of a negative number is the negated root of its magnitude."""
    root = root_n(x, y)
    assert root == pytest.approx(expected, rel=1e-9)

@pytest.mark.parametrize(
    "x, y, exponent",
    [
        (6, 36, 3),
        (3.4, 189.1, 2.3),
        (16, 4, 4),
        (27, 3, 3),
    ],
)
def test_root_n_product_property(x, y, exponent):
    """The root of a product equals the product of the roots."""
    root_of_product = root_n(x * y, exponent)
    product_of_roots = root_n(x, exponent) * root_n(y, exponent)
    assert root_of_product == pytest.approx(product_of_roots, rel=1e-9)

@pytest.mark.parametrize(
    "x, y, exponent",
    [
        (216, 36, 3),
        (3.4, 189.1, 2.3),
        (16, 4, 4),
        (27, 3, 3),
    ],
)
def test_root_n_quotient_property(x, y, exponent):
    """The root of a quotient equals the quotient of the roots."""
    root_of_quotient = root_n(x / y, exponent)
    quotient_of_roots = root_n(x, exponent) / root_n(y, exponent)
    assert root_of_quotient == pytest.approx(quotient_of_roots, rel=1e-5)

def test_root_n_convergence_error_possible():
    """root_n raises ConvergenceError when eps is unreachable within max_iter."""
    # One iteration with a 1e-50 tolerance is expected to exhaust the budget.
    expectation = pytest.raises(ConvergenceError)
    with expectation:
        root_n(2, 2, eps=1e-50, max_iter=1)

#======================================== LOGARITHM ===========================================

# (argument, base, expected) triples: exact integer logarithms, plus
# log(1) == 0 in several bases.
_LOG_CASES = [
    (8, 2, 3.0), (9, 3, 2.0), (16, 2, 4.0),
    (81, 3, 4.0), (100, 10, 2.0), (1000, 10, 3.0),
    (1, 10, 0.0), (1, 2, 0.0), (1, 5, 0.0),
]


@pytest.mark.parametrize("x, y, expected", _LOG_CASES)
def test_log_correctness(x, y, expected):
    """log(x, y) is the base-y logarithm of x, within rel=1e-9."""
    value = log(x, y)
    assert value == pytest.approx(expected, rel=1e-9)

@pytest.mark.parametrize(
    "x, y, base",
    [
        (8, 8, 8),
        (5, 25, 5),
        (3, 4, 9),
        (2.5, 3.4, 1.2),
    ],
)
def test_log_product_rule(x, y, base):
    """log of a product is the sum of the logs."""
    log_of_product = log(x * y, base)
    sum_of_logs = log(x, base) + log(y, base)
    assert log_of_product == pytest.approx(sum_of_logs, rel=1e-9)

@pytest.mark.parametrize(
    "x, y, base",
    [
        (36, 3, 8),
        (5, 25, 5),
        (3, 4, 9),
        (2.5, 3.4, 1.2),
        (8, 4, 2),
    ],
)
def test_log_quotient_rule(x, y, base):
    """log of a quotient is the difference of the logs."""
    log_of_quotient = log(x / y, base)
    difference_of_logs = log(x, base) - log(y, base)
    assert log_of_quotient == pytest.approx(difference_of_logs, rel=1e-9)

@pytest.mark.parametrize(
    "x, y, base",
    [
        (2, 8, 2),
        (3, 3.5, 2),
        (0.2, 4.5, 3),
    ],
)
def test_log_power_rule(x, y, base):
    """log(x**y) == y * log(x), using math.pow as the reference power."""
    log_of_power = log(math.pow(x, y), base)
    scaled_log = y * log(x, base)
    assert log_of_power == pytest.approx(scaled_log, rel=1e-9)

@pytest.mark.parametrize(
    "x",
    [8, 3.5, 2, 0.2, 4.5, 3],
)
def test_log_base_rule(x):
    """The logarithm of the base in its own base is 1."""
    value = log(x, x)
    assert value == pytest.approx(1.0, rel=1e-9)

@pytest.mark.parametrize("x", [8, 3.5, 2, 0.2, 4.5, 3])
def test_log_of_1_rule(x):
    """log(1) is 0 in every valid base."""
    # The expected value is 0, so a relative tolerance scales to nothing;
    # state the absolute tolerance explicitly instead.
    assert log(1, x) == pytest.approx(0.0, abs=1e-12)

@pytest.mark.parametrize(
    "x, base1, base2",
    [
        (9, 3, 2),
        (64, 4, 8),
        (128, 2, 32),
        (104.8, 23.4, 7.0),
    ],
)
def test_log_change_base(x, base1, base2):
    """Change of base: log_b1(x) == log_b2(x) / log_b2(b1)."""
    direct = log(x, base1)
    via_other_base = log(x, base2) / log(base1, base2)
    assert direct == pytest.approx(via_other_base, rel=1e-5)

@pytest.mark.parametrize(
    "x, y",
    [
        (3, 0),
        (3, -1),
        (3, 1),
        (10, -2),
    ],
)
def test_log_invalid_base(x, y):
    """A base that is non-positive or equal to 1 raises DomainError."""
    expectation = pytest.raises(DomainError)
    with expectation:
        log(x, y)

# Every base here is valid (positive and != 1), so a DomainError can only
# come from the non-positive argument, not from the base check.
@pytest.mark.parametrize("x, y", [(0, 3), (-1, 3), (-3, 2), (-3, 10), (-10, 29)])
def test_log_non_positive_argument(x, y):
    """log raises DomainError when its argument x is 0 or negative."""
    with pytest.raises(DomainError):
        log(x, y)

#======================================== NATURAL LOGARITHM ===========================================

# (x, ln(x)) pairs: ln(1) == 0, ln(e) == 1, ln(e**2) == 2, and a value
# below 1 checked against the math module.
_LN_CASES = [
    (1, 0.0),
    (math.e, 1.0),
    (math.e**2, 2.0),
    (0.5, math.log(0.5, math.e)),
]


@pytest.mark.parametrize("x, expected", _LN_CASES)
def test_ln_correctness(x, expected):
    """ln(x) matches the reference natural logarithm within rel=1e-9."""
    value = ln(x)
    assert value == pytest.approx(expected, rel=1e-9)

@pytest.mark.parametrize(
    "x",
    [0, -1, -2, -2.5, -100],
)
def test_ln_non_positive_argument(x):
    """ln of zero or a negative number raises DomainError."""
    expectation = pytest.raises(DomainError)
    with expectation:
        ln(x)

@pytest.mark.parametrize(
    "x, y",
    [
        (8, 8),
        (5, 25),
        (3, 4),
        (2.5, 3.4),
    ],
)
def test_ln_product_rule(x, y):
    """ln of a product is the sum of the natural logs."""
    log_of_product = ln(x * y)
    sum_of_logs = ln(y) + ln(x)
    assert log_of_product == pytest.approx(sum_of_logs, rel=1e-9)

@pytest.mark.parametrize(
    "x, y",
    [
        (36, 3),
        (8, 4),
        (6942, 89),
        (182.6, 23.4),
    ],
)
def test_ln_quotient_rule(x, y):
    """ln of a quotient is the difference of the natural logs."""
    log_of_quotient = ln(x / y)
    difference_of_logs = ln(x) - ln(y)
    assert log_of_quotient == pytest.approx(difference_of_logs)

@pytest.mark.parametrize(
    "x, y",
    [
        (2, 8),
        (3, 3.5),
        (9, 4.5),
    ],
)
def test_ln_power_rule(x, y):
    """ln(x**y) == y * ln(x), using math.pow as the reference power."""
    log_of_power = ln(math.pow(x, y))
    scaled_log = y * ln(x)
    assert log_of_power == pytest.approx(scaled_log)

def test_ln_base_rule():
    """The natural log of e is 1."""
    value = ln(math.e)
    assert value == pytest.approx(1.0, rel=1e-5)

def test_ln_of_1_rule():
    """The natural log of 1 is 0."""
    # The expected value is 0, so a relative tolerance scales to nothing;
    # state the absolute tolerance explicitly instead.
    assert ln(1) == pytest.approx(0.0, abs=1e-12)

def test_ln_convergence_error_possible():
    """ln raises ConvergenceError when eps is unreachable within max_iter.

    This previously called ``exp`` (copy-paste from the exp test), so ln's
    convergence handling was never exercised.
    """
    # NOTE(review): assumes ln accepts the same ``eps`` / ``max_iter`` keywords
    # that exp and root_n take in this file — confirm against mathlib.ln.
    with pytest.raises(ConvergenceError):
        ln(2, eps=1e-50, max_iter=1)