diff --git a/src/stamojo/special/__init__.mojo b/src/stamojo/special/__init__.mojo index f162cf6..4a92531 100644 --- a/src/stamojo/special/__init__.mojo +++ b/src/stamojo/special/__init__.mojo @@ -15,14 +15,17 @@ Functions provided: - Inverse error function (erfinv) - Log-beta function - Beta function +- Bessel functions of the first and second kind (j0, j1, jn, y0, y1) +- Modified Bessel functions and scaled variants (i0, i1, i0e, i1e) The Mojo standard library already provides erf, erfc, gamma, and lgamma, so we do not reimplement those here. -The modules of the subpackages are named with a leading underscore +The modules of the subpackages are named with a leading underscore (e.g., `_gamma`) to avoid conflicts with the standard library functions. """ from ._gamma import gammainc, gammaincc from ._beta import beta, lbeta, betainc from ._erf import erfinv, ndtri +from ._bessel import j0, j1, jn, i0, i1, i0e, i1e, y0, y1 diff --git a/src/stamojo/special/_bessel.mojo b/src/stamojo/special/_bessel.mojo new file mode 100644 index 0000000..70eb993 --- /dev/null +++ b/src/stamojo/special/_bessel.mojo @@ -0,0 +1,373 @@ +# ===----------------------------------------------------------------------=== # +# Stamojo - Special - Bessel functions +# Licensed under Apache 2.0 +# ===----------------------------------------------------------------------=== # +"""Bessel functions. + +This module provides implementations of Bessel functions of the first and +second kind, as well as modified Bessel functions and their exponentially +scaled variants. + +Functions: + - j0, j1, jn: Bessel functions of the first kind (orders 0, 1, n) + - i0, i1, i0e, i1e: Modified Bessel functions of the first kind + - y0, y1: Bessel functions of the second kind (orders 0, 1) + +References: + - https://en.wikipedia.org/wiki/Bessel_function +""" + +from math import cos, exp, inf, log, nan, sin, sqrt + +# === --------------------------------------------------------------------=== # +# General notes: +# TODO: Asymptotic expansions need to be implemented for large arguments to +# ensure accuracy and efficiency. The threshold for switching to asymptotic +# expansions should be determined empirically based on accuracy requirements. +# === ----------------------------------------------------------------------=== # + +# ===----------------------------------------------------------------------=== # +# Constants +# ===----------------------------------------------------------------------=== # + +comptime _MAX_SERIES_ITER: Int = 50 +comptime _PI: Float64 = 3.141592653589793 +comptime _PI_INV: Float64 = 1.0 / _PI +comptime _EULER_GAMMA: Float64 = 0.5772156649015328606 + +# ===----------------------------------------------------------------------=== # +# Helper functions +# ===----------------------------------------------------------------------=== # + + +fn _factorial(n: Int) -> Float64: + var res = 1.0 + for i in range(2, n + 1): + res *= Float64(i) + return res + + +# ===----------------------------------------------------------------------=== # +# Bessel functions of the first kind +# ===----------------------------------------------------------------------=== # + + +fn j0(x: Float64) -> Float64: + """Bessel function of the first kind of order 0. + + Args: + x: Input value. + + Returns: + J₀(x). + + Examples: + ```mojo + from stamojo.special import j0 + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(j0(1.0), 0.7651976865579666, atol=1e-12) + ``` + """ + # J₀ is even: J₀(-x) = J₀(x). + var ax = abs(x) + + # TODO: Determine asymptotic threshold empirically. + if ax > 10.0: + var t = ax - _PI * 0.25 + return sqrt(2.0 * _PI_INV / ax) * cos(t) + + var term = 1.0 + var res = term + var x2 = ax * ax * 0.25 + for k in range(1, _MAX_SERIES_ITER): + term *= -x2 / (Float64(k) * Float64(k)) + res += term + return res + + +fn j1(x: Float64) -> Float64: + """Bessel function of the first kind of order 1. + + Args: + x: Input value. + + Returns: + J₁(x). + + Examples: + ```mojo + from stamojo.special import j1 + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(j1(1.0), 0.44005058574493355, atol=1e-12) + ``` + """ + # J₁ is odd: J₁(-x) = -J₁(x). + var ax = abs(x) + var sign: Float64 = 1.0 if x >= 0.0 else -1.0 + + # TODO: Determine asymptotic threshold empirically. + if ax > 10.0: + var t = ax - 3.0 * _PI * 0.25 + return sign * sqrt(2.0 * _PI_INV / ax) * cos(t) + + var term = ax * 0.5 + var res = term + var x2 = ax * ax * 0.25 + for k in range(1, _MAX_SERIES_ITER): + term *= -x2 / (Float64(k) * Float64(k + 1)) + res += term + return sign * res + + +fn jn[n: Int](x: Float64) -> Float64: + """Bessel function of the first kind of order *n*. + + Parameters: + n: Order of the Bessel function (integer). + + Args: + x: Input value. + + Returns: + Jₙ(x). + """ + + @parameter + if n == 0: + return j0(x) + + @parameter + if n == 1: + return j1(x) + + comptime m = n if n >= 0 else -n + comptime sign: Float64 = -1.0 if (n < 0 and (m % 2 == 1)) else 1.0 + + var ax = abs(x) + + # For m <= ax, use the forward recurrence: + # J_{k+1}(x) = (2k/x) J_k(x) - J_{k-1}(x) + if Float64(m) <= ax: + var jm1 = j0(x) # J_0 + var jcur = j1(x) # J_1 + for k in range(1, m): + var jnext = (2.0 * Float64(k) / x) * jcur - jm1 + jm1 = jcur + jcur = jnext + return sign * jcur + + # For m > ax, use power series. + var fact = _factorial(m) + var term = 1.0 + for _ in range(m): + term *= x * 0.5 + term /= fact + + var res = term + var x2 = x * x * 0.25 + for k in range(1, _MAX_SERIES_ITER): + term *= -x2 / (Float64(k) * Float64(k + m)) + res += term + return sign * res + + +# ===----------------------------------------------------------------------=== # +# Modified Bessel functions of the first kind and their scaled forms +# ===----------------------------------------------------------------------=== # + + +fn i0(x: Float64) -> Float64: + """Modified Bessel function of the first kind of order 0. + + Args: + x: Input value. + + Returns: + I₀(x). + + Examples: + ```mojo + from stamojo.special import i0 + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(i0(1.0), 1.2660658777520082, atol=1e-12) + ``` + """ + var term = 1.0 + var res = term + var x2 = x * x * 0.25 + for k in range(1, _MAX_SERIES_ITER): + term *= x2 / (Float64(k) * Float64(k)) + res += term + return res + + +fn i1(x: Float64) -> Float64: + """Modified Bessel function of the first kind of order 1. + + Args: + x: Input value. + + Returns: + I₁(x). + + Examples: + ```mojo + from stamojo.special import i1 + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(i1(1.0), 0.5651591039924851, atol=1e-12) + ``` + """ + var term = x * 0.5 + var res = term + var x2 = x * x * 0.25 + for k in range(1, _MAX_SERIES_ITER): + term *= x2 / (Float64(k) * Float64(k + 1)) + res += term + return res + + +fn i0e(x: Float64) -> Float64: + """Exponentially scaled modified Bessel function of the first kind + of order 0: ``i0e(x) = exp(-|x|) * i0(x)``. + + Args: + x: Input value. + + Returns: + Value of exp(-|x|) * I₀(x). + + Examples: + ```mojo + from stamojo.special import i0e + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(i0e(1.0), 0.4657596075936405, atol=1e-12) + ``` + """ + return i0(x) * exp(-abs(x)) + + +fn i1e(x: Float64) -> Float64: + """Exponentially scaled modified Bessel function of the first kind + of order 1: ``i1e(x) = exp(-|x|) * i1(x)``. + + Args: + x: Input value. + + Returns: + Value of exp(-|x|) * I₁(x). + + Examples: + ```mojo + from stamojo.special import i1e + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(i1e(1.0), 0.2079104153497085, atol=1e-12) + ``` + """ + return i1(x) * exp(-abs(x)) + + +# ===----------------------------------------------------------------------=== # +# Bessel functions of the second kind +# ===----------------------------------------------------------------------=== # + + +fn y0(x: Float64) -> Float64: + """Bessel function of the second kind of order 0. + + Defined for x > 0. Returns -∞ at x = 0 and NaN for x < 0. + + Args: + x: Input value (must be positive). + + Returns: + Y₀(x). + + Examples: + ```mojo + from stamojo.special import y0 + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(y0(1.0), 0.08825696421567697, atol=1e-12) + ``` + """ + if x == 0.0: + return -inf[DType.float64]() + if x < 0.0: + return nan[DType.float64]() + + if x < 8.0: + var j0x = j0(x) + var x2 = x * x * 0.25 + var term = x2 + var sum = 0.0 + var h = 1.0 + for k in range(1, _MAX_SERIES_ITER): + if k > 1: + h += 1.0 / Float64(k) + sum += term * h + term *= -x2 / (Float64(k + 1) * Float64(k + 1)) + return (2.0 / _PI) * ((log(x * 0.5) + _EULER_GAMMA) * j0x + sum) + + var t = x - _PI * 0.25 + return sqrt(2.0 / (_PI * x)) * sin(t) + + +fn y1(x: Float64) -> Float64: + """Bessel function of the second kind of order 1. + + Defined for x > 0. Returns -∞ at x = 0 and NaN for x < 0. + + Args: + x: Input value (must be positive). + + Returns: + Y₁(x). + + Examples: + ```mojo + from stamojo.special import y1 + from testing import assert_almost_equal + + fn main() raises: + assert_almost_equal(y1(1.0), -0.7812128213002887, atol=1e-12) + ``` + """ + if x == 0.0: + return -inf[DType.float64]() + if x < 0.0: + return nan[DType.float64]() + + if x < 8.0: + var j1x = j1(x) + var x2 = x * x * 0.25 + var term = x * 0.5 + var sum = 0.0 + var hk = 0.0 + + for k in range(_MAX_SERIES_ITER): + var hk1 = hk + 1.0 / Float64(k + 1) + sum += term * (hk + hk1) + term *= -x2 / (Float64(k + 1) * Float64(k + 2)) + hk = hk1 + + return ( + (-2.0 * _PI_INV / x) + + (2.0 * _PI_INV) * (log(x * 0.5) + _EULER_GAMMA) * j1x + - sum * _PI_INV + ) + + var t = x - 3.0 * _PI * 0.25 + return sqrt(2.0 / (_PI * x)) * sin(t) diff --git a/tests/test_distributions.mojo b/tests/test_distributions.mojo index 50707cc..e0272b4 100644 --- a/tests/test_distributions.mojo +++ b/tests/test_distributions.mojo @@ -62,7 +62,6 @@ fn test_normal_pdf() raises: # Non-standard normal: N(5, 2), pdf at mean = 1/(σ√(2π)) var n2 = Normal(5.0, 2.0) assert_almost_equal(n2.pdf(5.0), 0.19947114020071635, atol=1e-12) - print("✓ test_normal_pdf passed") fn test_normal_cdf() raises: @@ -74,7 +73,6 @@ fn test_normal_cdf() raises: # Tails assert_almost_equal(n.cdf(-10.0), 0.0, atol=1e-15) assert_almost_equal(n.cdf(10.0), 1.0, atol=1e-15) - print("✓ test_normal_cdf passed") fn test_normal_ppf() raises: @@ -88,7 +86,6 @@ fn test_normal_ppf() raises: var n2 = Normal(10.0, 3.0) assert_almost_equal(n2.ppf(0.5), 10.0, atol=1e-10) assert_almost_equal(n2.ppf(n2.cdf(15.0)), 15.0, atol=1e-10) - print("✓ test_normal_ppf passed") fn test_normal_cdf_ppf_roundtrip() raises: @@ -100,15 +97,12 @@ fn test_normal_cdf_ppf_roundtrip() raises: var p = ps[i] assert_almost_equal(n.cdf(n.ppf(p)), p, atol=1e-10) - print("✓ test_normal_cdf_ppf_roundtrip passed") - fn test_normal_sf() raises: """Test Normal survival function.""" var n = Normal(0.0, 1.0) assert_almost_equal(n.sf(0.0), 0.5, atol=1e-15) assert_almost_equal(n.cdf(1.5) + n.sf(1.5), 1.0, atol=1e-15) - print("✓ test_normal_sf passed") fn test_normal_stats() raises: @@ -117,7 +111,6 @@ fn test_normal_stats() raises: assert_almost_equal(n.mean(), 3.0, atol=1e-15) assert_almost_equal(n.variance(), 4.0, atol=1e-15) assert_almost_equal(n.std(), 2.0, atol=1e-15) - print("✓ test_normal_stats passed") fn test_normal_scipy() raises: @@ -137,8 +130,6 @@ fn test_normal_scipy() raises: assert_almost_equal(n.pdf(x), sp_pdf, atol=1e-12) assert_almost_equal(n.cdf(x), sp_cdf, atol=1e-12) - print("✓ test_normal_scipy passed") - # ===----------------------------------------------------------------------=== # # Student's t distribution tests @@ -150,7 +141,6 @@ fn test_t_pdf_symmetry() raises: var t = StudentT(5.0) assert_almost_equal(t.pdf(1.0), t.pdf(-1.0), atol=1e-15) assert_almost_equal(t.pdf(2.5), t.pdf(-2.5), atol=1e-15) - print("✓ test_t_pdf_symmetry passed") fn test_t_cdf() raises: @@ -164,7 +154,6 @@ fn test_t_cdf() raises: var t5 = StudentT(5.0) assert_almost_equal(t5.cdf(0.0), 0.5, atol=1e-12) assert_almost_equal(t5.cdf(2.0) + t5.cdf(-2.0), 1.0, atol=1e-10) - print("✓ test_t_cdf passed") fn test_t_ppf() raises: @@ -175,7 +164,6 @@ fn test_t_ppf() raises: assert_almost_equal(t5.cdf(t5.ppf(0.975)), 0.975, atol=1e-6) assert_almost_equal(t5.cdf(t5.ppf(0.025)), 0.025, atol=1e-6) assert_almost_equal(t5.cdf(t5.ppf(0.9)), 0.9, atol=1e-6) - print("✓ test_t_ppf passed") fn test_t_stats() raises: @@ -183,7 +171,6 @@ fn test_t_stats() raises: var t5 = StudentT(5.0) assert_almost_equal(t5.mean(), 0.0, atol=1e-15) assert_almost_equal(t5.variance(), 5.0 / 3.0, atol=1e-12) - print("✓ test_t_stats passed") fn test_t_scipy() raises: @@ -207,8 +194,6 @@ fn test_t_scipy() raises: assert_almost_equal(t.pdf(x), sp_pdf, atol=1e-10) assert_almost_equal(t.cdf(x), sp_cdf, atol=1e-6) - print("✓ test_t_scipy passed") - # ===----------------------------------------------------------------------=== # # Chi-squared distribution tests @@ -224,7 +209,6 @@ fn test_chi2_cdf() raises: assert_almost_equal(c2.cdf(2.0), 1.0 - exp(-1.0), atol=1e-10) assert_almost_equal(c2.cdf(4.0), 1.0 - exp(-2.0), atol=1e-10) assert_almost_equal(c2.cdf(0.0), 0.0, atol=1e-15) - print("✓ test_chi2_cdf passed") fn test_chi2_ppf() raises: @@ -233,7 +217,6 @@ fn test_chi2_ppf() raises: assert_almost_equal(c5.cdf(c5.ppf(0.95)), 0.95, atol=1e-6) assert_almost_equal(c5.cdf(c5.ppf(0.5)), 0.5, atol=1e-6) assert_almost_equal(c5.cdf(c5.ppf(0.01)), 0.01, atol=1e-6) - print("✓ test_chi2_ppf passed") fn test_chi2_stats() raises: @@ -241,7 +224,6 @@ fn test_chi2_stats() raises: var c5 = ChiSquared(5.0) assert_almost_equal(c5.mean(), 5.0, atol=1e-15) assert_almost_equal(c5.variance(), 10.0, atol=1e-15) - print("✓ test_chi2_stats passed") fn test_chi2_scipy() raises: @@ -263,8 +245,6 @@ fn test_chi2_scipy() raises: var sp_cdf = _py_f64(sp.chi2.cdf(x, df)) assert_almost_equal(c.cdf(x), sp_cdf, atol=1e-6) - print("✓ test_chi2_scipy passed") - # ===----------------------------------------------------------------------=== # # F-distribution tests @@ -281,7 +261,6 @@ fn test_f_cdf_boundary() raises: var c3 = f.cdf(5.0) if not (c1 < c2 and c2 < c3): raise Error("F CDF not monotonically increasing") - print("✓ test_f_cdf_boundary passed") fn test_f_ppf() raises: @@ -290,7 +269,6 @@ fn test_f_ppf() raises: assert_almost_equal(f.cdf(f.ppf(0.95)), 0.95, atol=1e-6) assert_almost_equal(f.cdf(f.ppf(0.5)), 0.5, atol=1e-6) assert_almost_equal(f.cdf(f.ppf(0.1)), 0.1, atol=1e-6) - print("✓ test_f_ppf passed") fn test_f_stats() raises: @@ -298,7 +276,6 @@ fn test_f_stats() raises: var f = FDist(5.0, 10.0) # mean = d2 / (d2 - 2) = 10/8 = 1.25 assert_almost_equal(f.mean(), 1.25, atol=1e-12) - print("✓ test_f_stats passed") fn test_f_scipy() raises: @@ -316,8 +293,6 @@ fn test_f_scipy() raises: var sp_cdf = _py_f64(sp.f.cdf(x, 5.0, 10.0)) assert_almost_equal(f.cdf(x), sp_cdf, atol=1e-6) - print("✓ test_f_scipy passed") - # ===----------------------------------------------------------------------=== # # Exponential distribution tests @@ -343,7 +318,6 @@ fn test_expon_pdf() raises: var e3 = Exponential(loc=1.0) assert_almost_equal(e3.pdf(1.0), 1.0, atol=1e-15) assert_almost_equal(e3.pdf(0.5), 0.0, atol=1e-15) - print("✓ test_expon_pdf passed") fn test_expon_logpdf() raises: @@ -358,7 +332,6 @@ fn test_expon_logpdf() raises: # With scale=3: logpdf(x) = -x/3 - log(3) var e2 = Exponential(scale=3.0) assert_almost_equal(e2.logpdf(3.0), -1.0 - log(3.0), atol=1e-15) - print("✓ test_expon_logpdf passed") fn test_expon_cdf() raises: @@ -379,7 +352,6 @@ fn test_expon_cdf() raises: # With scale=0.5 (rate=2): CDF(x) = 1 - exp(-2x) var e2 = Exponential(scale=0.5) assert_almost_equal(e2.cdf(1.0), 1.0 - exp(-2.0), atol=1e-15) - print("✓ test_expon_cdf passed") fn test_expon_sf() raises: @@ -393,7 +365,6 @@ fn test_expon_sf() raises: assert_almost_equal(e.cdf(xs[i]) + e.sf(xs[i]), 1.0, atol=1e-15) # SF(x < loc) = 1 assert_almost_equal(e.sf(-1.0), 1.0, atol=1e-15) - print("✓ test_expon_sf passed") fn test_expon_ppf() raises: @@ -408,7 +379,6 @@ fn test_expon_ppf() raises: # With loc and scale var e2 = Exponential(loc=1.0, scale=2.0) assert_almost_equal(e2.ppf(0.5), 1.0 + 2.0 * log(2.0), atol=1e-12) - print("✓ test_expon_ppf passed") fn test_expon_cdf_ppf_roundtrip() raises: @@ -423,7 +393,6 @@ fn test_expon_cdf_ppf_roundtrip() raises: for i in range(len(ps)): var p = ps[i] assert_almost_equal(e2.cdf(e2.ppf(p)), p, atol=1e-12) - print("✓ test_expon_cdf_ppf_roundtrip passed") fn test_expon_isf() raises: @@ -438,7 +407,6 @@ fn test_expon_isf() raises: for i in range(len(qs)): var q = qs[i] assert_almost_equal(e.isf(q), e.ppf(1.0 - q), atol=1e-12) - print("✓ test_expon_isf passed") fn test_expon_logcdf_logsf() raises: @@ -449,7 +417,6 @@ fn test_expon_logcdf_logsf() raises: var x = xs[i] assert_almost_equal(e.logcdf(x), log(e.cdf(x)), atol=1e-12) assert_almost_equal(e.logsf(x), log(e.sf(x)), atol=1e-15) - print("✓ test_expon_logcdf_logsf passed") fn test_expon_stats() raises: @@ -466,7 +433,6 @@ fn test_expon_stats() raises: assert_almost_equal(e2.variance(), 9.0, atol=1e-15) assert_almost_equal(e2.std(), 3.0, atol=1e-15) assert_almost_equal(e2.median(), 2.0 + 3.0 * log(2.0), atol=1e-15) - print("✓ test_expon_stats passed") fn test_expon_loc_scale() raises: @@ -482,7 +448,6 @@ fn test_expon_loc_scale() raises: assert_almost_equal(e.sf(loc), 1.0, atol=1e-15) # CDF(loc + scale) = 1 - exp(-1) assert_almost_equal(e.cdf(loc + scale), 1.0 - exp(-1.0), atol=1e-15) - print("✓ test_expon_loc_scale passed") fn test_expon_scipy() raises: @@ -522,8 +487,6 @@ fn test_expon_scipy() raises: assert_almost_equal(e2.pdf(x), sp_pdf2, atol=1e-12) assert_almost_equal(e2.cdf(x), sp_cdf2, atol=1e-12) - print("✓ test_expon_scipy passed") - # ===----------------------------------------------------------------------=== # # Main test runner diff --git a/tests/test_hypothesis.mojo b/tests/test_hypothesis.mojo index defb741..c79ae2d 100644 --- a/tests/test_hypothesis.mojo +++ b/tests/test_hypothesis.mojo @@ -71,8 +71,6 @@ fn test_ttest_1samp_basic() raises: if p_val > 0.05: raise Error("ttest_1samp: expected p < 0.05, got " + String(p_val)) - print("✓ test_ttest_1samp_basic passed") - fn test_ttest_1samp_no_effect() raises: """Test one-sample t-test when data mean ≈ mu0.""" @@ -83,8 +81,6 @@ fn test_ttest_1samp_no_effect() raises: # p-value should be 1.0 (no evidence against H0) assert_almost_equal(result[1], 1.0, atol=1e-6) - print("✓ test_ttest_1samp_no_effect passed") - fn test_ttest_1samp_scipy() raises: """Test one-sample t-test against scipy.""" @@ -104,8 +100,6 @@ fn test_ttest_1samp_scipy() raises: assert_almost_equal(result[0], sp_t, atol=1e-6) assert_almost_equal(result[1], sp_p, atol=1e-4) - print("✓ test_ttest_1samp_scipy passed") - fn test_ttest_ind_welch() raises: """Test Welch's two-sample t-test.""" @@ -120,8 +114,6 @@ fn test_ttest_ind_welch() raises: "ttest_ind Welch: expected p < 0.05, got " + String(result[1]) ) - print("✓ test_ttest_ind_welch passed") - fn test_ttest_ind_scipy() raises: """Test Welch's t-test against scipy.""" @@ -144,8 +136,6 @@ fn test_ttest_ind_scipy() raises: assert_almost_equal(result[0], sp_t, atol=1e-4) assert_almost_equal(result[1], sp_p, atol=1e-3) - print("✓ test_ttest_ind_scipy passed") - fn test_ttest_rel() raises: """Test paired t-test.""" @@ -162,8 +152,6 @@ fn test_ttest_rel() raises: assert_almost_equal(result[0], result_1samp[0], atol=1e-12) assert_almost_equal(result[1], result_1samp[1], atol=1e-12) - print("✓ test_ttest_rel passed") - # ===----------------------------------------------------------------------=== # # Chi-squared tests @@ -185,8 +173,6 @@ fn test_chi2_gof_fair_die() raises: "chi2_gof: expected non-significant, got p=" + String(result[1]) ) - print("✓ test_chi2_gof_fair_die passed") - fn test_chi2_ind_basic() raises: """Test chi-squared independence test with 2×2 table.""" @@ -202,8 +188,6 @@ fn test_chi2_ind_basic() raises: assert_almost_equal(result[0], 0.0, atol=1e-10) assert_almost_equal(result[1], 1.0, atol=1e-4) - print("✓ test_chi2_ind_basic passed") - fn test_chi2_ind_scipy() raises: """Test chi-squared independence test against scipy.""" @@ -229,8 +213,6 @@ fn test_chi2_ind_scipy() raises: assert_almost_equal(result[0], sp_chi2, atol=1e-4) assert_almost_equal(result[1], sp_p, atol=1e-3) - print("✓ test_chi2_ind_scipy passed") - # ===----------------------------------------------------------------------=== # # Kolmogorov-Smirnov test @@ -260,8 +242,6 @@ fn test_ks_normal_data() raises: + String(result[1]) ) - print("✓ test_ks_normal_data passed") - fn test_ks_uniform_data() raises: """Test KS test with uniform data (should reject N(0,1)).""" @@ -278,8 +258,6 @@ fn test_ks_uniform_data() raises: + String(result[1]) ) - print("✓ test_ks_uniform_data passed") - # ===----------------------------------------------------------------------=== # # Correlation tests @@ -300,8 +278,6 @@ fn test_pearsonr_perfect() raises: if result[1] > 0.001: raise Error("pearsonr: expected p ≈ 0 for perfect correlation") - print("✓ test_pearsonr_perfect passed") - fn test_pearsonr_negative() raises: """Test Pearson correlation for negative correlation.""" @@ -314,8 +290,6 @@ fn test_pearsonr_negative() raises: var result = pearsonr(x, y) assert_almost_equal(result[0], -1.0, atol=1e-10) - print("✓ test_pearsonr_negative passed") - fn test_pearsonr_scipy() raises: """Test Pearson correlation against scipy.""" @@ -338,8 +312,6 @@ fn test_pearsonr_scipy() raises: assert_almost_equal(result[0], sp_r, atol=1e-6) assert_almost_equal(result[1], sp_p, atol=1e-3) - print("✓ test_pearsonr_scipy passed") - fn test_spearmanr_perfect_monotone() raises: """Test Spearman correlation with perfect monotone data.""" @@ -352,8 +324,6 @@ fn test_spearmanr_perfect_monotone() raises: var result = spearmanr(x, y) assert_almost_equal(result[0], 1.0, atol=1e-10) - print("✓ test_spearmanr_perfect_monotone passed") - fn test_spearmanr_scipy() raises: """Test Spearman correlation against scipy.""" @@ -374,8 +344,6 @@ fn test_spearmanr_scipy() raises: var result = spearmanr(x, y) assert_almost_equal(result[0], sp_rho, atol=1e-4) - print("✓ test_spearmanr_scipy passed") - fn test_kendalltau_concordant() raises: """Test Kendall's tau with perfectly concordant data.""" @@ -388,8 +356,6 @@ fn test_kendalltau_concordant() raises: var result = kendalltau(x, y) assert_almost_equal(result[0], 1.0, atol=1e-10) - print("✓ test_kendalltau_concordant passed") - fn test_kendalltau_discordant() raises: """Test Kendall's tau with perfectly discordant data.""" @@ -402,8 +368,6 @@ fn test_kendalltau_discordant() raises: var result = kendalltau(x, y) assert_almost_equal(result[0], -1.0, atol=1e-10) - print("✓ test_kendalltau_discordant passed") - # ===----------------------------------------------------------------------=== # # One-way ANOVA @@ -425,8 +389,6 @@ fn test_f_oneway_identical() raises: assert_almost_equal(result[0], 0.0, atol=1e-10) assert_almost_equal(result[1], 1.0, atol=1e-4) - print("✓ test_f_oneway_identical passed") - fn test_f_oneway_different() raises: """Test ANOVA with clearly different group means.""" @@ -447,8 +409,6 @@ fn test_f_oneway_different() raises: + String(result[1]) ) - print("✓ test_f_oneway_different passed") - fn test_f_oneway_scipy() raises: """Test ANOVA against scipy.stats.f_oneway.""" @@ -477,8 +437,6 @@ fn test_f_oneway_scipy() raises: assert_almost_equal(result[0], sp_f, atol=1e-4) assert_almost_equal(result[1], sp_p, atol=1e-3) - print("✓ test_f_oneway_scipy passed") - # ===----------------------------------------------------------------------=== # # Main test runner diff --git a/tests/test_special.mojo b/tests/test_special.mojo index 4610d76..fdc10ff 100644 --- a/tests/test_special.mojo +++ b/tests/test_special.mojo @@ -20,7 +20,23 @@ from math import exp, log, lgamma, erf, sqrt from python import Python, PythonObject from testing import assert_almost_equal, TestSuite -from stamojo.special import gammainc, gammaincc, beta, lbeta, betainc, erfinv +from stamojo.special import ( + gammainc, + gammaincc, + beta, + lbeta, + betainc, + erfinv, + j0, + j1, + jn, + i0, + i1, + i0e, + i1e, + y0, + y1, +) # ===----------------------------------------------------------------------=== # @@ -92,7 +108,6 @@ fn test_gammainc_boundary() raises: assert_almost_equal(gammainc(5.0, 0.0), 0.0, atol=1e-15) assert_almost_equal(gammaincc(1.0, 0.0), 1.0, atol=1e-15) assert_almost_equal(gammaincc(5.0, 0.0), 1.0, atol=1e-15) - print("✓ test_gammainc_boundary passed") fn test_gammainc_exponential() raises: @@ -113,8 +128,6 @@ fn test_gammainc_exponential() raises: atol=1e-12, ) - print("✓ test_gammainc_exponential passed") - fn test_gammainc_half() raises: """Test P(0.5, x) = erf(sqrt(x)).""" @@ -134,8 +147,6 @@ fn test_gammainc_half() raises: atol=1e-10, ) - print("✓ test_gammainc_half passed") - fn test_gammainc_integer_a() raises: """Test gammainc/gammaincc against the Poisson sum formula for integer a.""" @@ -170,8 +181,6 @@ fn test_gammainc_integer_a() raises: "gammaincc(" + String(a_int) + ", " + String(x) + ")", ) - print("✓ test_gammainc_integer_a passed") - fn test_gammainc_scipy() raises: """Test gammainc/gammaincc against scipy for non-integer a values.""" @@ -207,8 +216,6 @@ fn test_gammainc_scipy() raises: atol=1e-10, ) - print("✓ test_gammainc_scipy passed") - fn test_gammainc_complementary() raises: """Test P(a,x) + Q(a,x) = 1.""" @@ -228,8 +235,6 @@ fn test_gammainc_complementary() raises: var x = test_cases[i][1] assert_almost_equal(gammainc(a, x) + gammaincc(a, x), 1.0, atol=1e-12) - print("✓ test_gammainc_complementary passed") - # ===----------------------------------------------------------------------=== # # Tests for beta and incomplete beta @@ -248,15 +253,12 @@ fn test_beta_basic() raises: var expected = exp(lgamma(a) + lgamma(b) - lgamma(a + b)) assert_almost_equal(beta(a, b), expected, atol=1e-12) - print("✓ test_beta_basic passed") - fn test_betainc_boundary() raises: """Test betainc boundary values.""" assert_almost_equal(betainc(2.0, 3.0, 0.0), 0.0, atol=1e-15) assert_almost_equal(betainc(2.0, 3.0, 1.0), 1.0, atol=1e-15) assert_almost_equal(betainc(1.0, 1.0, 0.5), 0.5, atol=1e-12) - print("✓ test_betainc_boundary passed") fn test_betainc_symmetric() raises: @@ -267,8 +269,6 @@ fn test_betainc_symmetric() raises: var a = test_a[i] assert_almost_equal(betainc(a, a, 0.5), 0.5, atol=1e-10) - print("✓ test_betainc_symmetric passed") - fn test_betainc_symmetry_identity() raises: """Test I_x(a,b) = 1 - I_{1-x}(b,a).""" @@ -282,7 +282,6 @@ fn test_betainc_symmetry_identity() raises: 1.0 - betainc(7.0, 2.0, 0.7), atol=1e-10, ) - print("✓ test_betainc_symmetry_identity passed") fn test_betainc_known_values() raises: @@ -291,7 +290,6 @@ fn test_betainc_known_values() raises: assert_almost_equal(betainc(1.0, 1.0, x), x, atol=1e-12) assert_almost_equal(betainc(1.0, 2.0, x), 1.0 - (1.0 - x) ** 2, atol=1e-10) assert_almost_equal(betainc(1.0, 5.0, x), 1.0 - (1.0 - x) ** 5, atol=1e-10) - print("✓ test_betainc_known_values passed") fn test_betainc_scipy() raises: @@ -320,8 +318,6 @@ fn test_betainc_scipy() raises: atol=1e-10, ) - print("✓ test_betainc_scipy passed") - # ===----------------------------------------------------------------------=== # # Tests for inverse error function @@ -349,8 +345,6 @@ fn test_erfinv_basic() raises: var x = erfinv(p) assert_almost_equal(erf(x), p, atol=1e-8) - print("✓ test_erfinv_basic passed") - fn test_erfinv_symmetry() raises: """Test erfinv(-p) = -erfinv(p).""" @@ -360,8 +354,6 @@ fn test_erfinv_symmetry() raises: var p = test_vals[i] assert_almost_equal(erfinv(-p), -erfinv(p), atol=1e-12) - print("✓ test_erfinv_symmetry passed") - fn test_erfinv_scipy() raises: """Test erfinv against scipy.special.erfinv.""" @@ -396,7 +388,122 @@ fn test_erfinv_scipy() raises: atol=1e-10, ) - print("✓ test_erfinv_scipy passed") + +# ===----------------------------------------------------------------------=== # +# Tests for Bessel functions +# ===----------------------------------------------------------------------=== # + + +fn test_bessel_basic_values() raises: + """Test basic Bessel function values.""" + assert_almost_equal(j0(0.0), 1.0, atol=1e-15) + assert_almost_equal(j1(0.0), 0.0, atol=1e-15) + assert_almost_equal(jn[1](0.0), 0.0, atol=1e-15) + assert_almost_equal(jn[2](0.0), 0.0, atol=1e-15) + assert_almost_equal(i0(0.0), 1.0, atol=1e-15) + assert_almost_equal(i1(0.0), 0.0, atol=1e-15) + assert_almost_equal(y0(1.0), 0.08825696421567697, atol=1e-12) + assert_almost_equal(y1(1.0), -0.7812128213002887, atol=1e-12) + + +fn test_bessel_symmetry() raises: + """Test symmetry relations for Bessel functions.""" + var x = 2.5 + assert_almost_equal(j0(-x), j0(x), atol=1e-12) + assert_almost_equal(j1(-x), -j1(x), atol=1e-12) + assert_almost_equal(jn[2](-x), jn[2](x), atol=1e-12) + + +fn test_bessel_scaled() raises: + """Test scaled modified Bessel functions.""" + var x = 2.0 + assert_almost_equal(i0e(x), i0(x) * exp(-x), atol=1e-12) + assert_almost_equal(i1e(x), i1(x) * exp(-x), atol=1e-12) + + +fn test_bessel_scipy() raises: + """Test Bessel functions against scipy.special.""" + var sp = _load_scipy() + if sp is None: + print("test_bessel_scipy skipped (scipy not available)") + return + + var xs: List[Float64] = [1e-3, 0.5, 1.0, 2.5] + + for i in range(len(xs)): + var x = xs[i] + var sp_j0 = _py_f64(sp.j0(x)) + var sp_j1 = _py_f64(sp.j1(x)) + var sp_i0 = _py_f64(sp.i0(x)) + var sp_i1 = _py_f64(sp.i1(x)) + var sp_i0e = _py_f64(sp.i0e(x)) + var sp_i1e = _py_f64(sp.i1e(x)) + var sp_y0 = _py_f64(sp.y0(x)) + var sp_y1 = _py_f64(sp.y1(x)) + _assert_with_scipy( + j0(x), + sp_j0, + sp, + sp_j0, + "j0(" + String(x) + ")", + atol=1e-10, + ) + _assert_with_scipy( + j1(x), + sp_j1, + sp, + sp_j1, + "j1(" + String(x) + ")", + atol=1e-10, + ) + _assert_with_scipy( + i0(x), + sp_i0, + sp, + sp_i0, + "i0(" + String(x) + ")", + atol=1e-10, + ) + _assert_with_scipy( + i1(x), + sp_i1, + sp, + sp_i1, + "i1(" + String(x) + ")", + atol=1e-10, + ) + _assert_with_scipy( + i0e(x), + sp_i0e, + sp, + sp_i0e, + "i0e(" + String(x) + ")", + atol=1e-10, + ) + _assert_with_scipy( + i1e(x), + sp_i1e, + sp, + sp_i1e, + "i1e(" + String(x) + ")", + atol=1e-10, + ) + _assert_with_scipy( + y0(x), + sp_y0, + sp, + sp_y0, + "y0(" + String(x) + ")", + atol=1e-10, + ) + _assert_with_scipy( + y1(x), + sp_y1, + sp, + sp_y1, + "y1(" + String(x) + ")", + atol=1e-10, + ) # ===----------------------------------------------------------------------=== # diff --git a/tests/test_stats.mojo b/tests/test_stats.mojo index 6ca44ce..d5df865 100644 --- a/tests/test_stats.mojo +++ b/tests/test_stats.mojo @@ -51,8 +51,6 @@ fn test_mean() raises: var data2: List[Float64] = [10.0] assert_almost_equal(mean(data2), 10.0, atol=1e-15) - print("✓ test_mean passed") - fn test_variance() raises: """Test variance (population and sample).""" @@ -63,8 +61,6 @@ fn test_variance() raises: # Sample variance = 32/7 assert_almost_equal(variance(data, ddof=1), 32.0 / 7.0, atol=1e-12) - print("✓ test_variance passed") - fn test_std() raises: """Test standard deviation.""" @@ -72,8 +68,6 @@ fn test_std() raises: assert_almost_equal(std(data, ddof=0), 2.0, atol=1e-12) - print("✓ test_std passed") - fn test_median_odd() raises: """Test median with odd-length data.""" @@ -83,16 +77,12 @@ fn test_median_odd() raises: var data2: List[Float64] = [5.0, 1.0, 3.0, 2.0, 4.0] assert_almost_equal(median(data2), 3.0, atol=1e-15) - print("✓ test_median_odd passed") - fn test_median_even() raises: """Test median with even-length data.""" var data: List[Float64] = [3.0, 1.0, 2.0, 4.0] assert_almost_equal(median(data), 2.5, atol=1e-15) - print("✓ test_median_even passed") - fn test_quantile() raises: """Test quantile function.""" @@ -108,16 +98,12 @@ fn test_quantile() raises: # q=0.25 assert_almost_equal(quantile(data, 0.25), 3.25, atol=1e-12) - print("✓ test_quantile passed") - fn test_skewness_symmetric() raises: """Test skewness of perfectly symmetric data is 0.""" var data: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0] assert_almost_equal(skewness(data), 0.0, atol=1e-12) - print("✓ test_skewness_symmetric passed") - fn test_kurtosis_uniform() raises: """Test kurtosis of uniform-like data is negative (platykurtic).""" @@ -132,8 +118,6 @@ fn test_kurtosis_uniform() raises: "Kurtosis of uniform data out of expected range: " + String(k) ) - print("✓ test_kurtosis_uniform passed") - fn test_min_max() raises: """Test data_min and data_max.""" @@ -142,8 +126,6 @@ fn test_min_max() raises: assert_almost_equal(data_min(data), 1.0, atol=1e-15) assert_almost_equal(data_max(data), 9.0, atol=1e-15) - print("✓ test_min_max passed") - fn test_scipy_comparison() raises: """Test descriptive statistics against numpy/scipy."""