diff --git a/src/stamojo/stats/__init__.mojo b/src/stamojo/stats/__init__.mojo
index 892665a..77e64b7 100644
--- a/src/stamojo/stats/__init__.mojo
+++ b/src/stamojo/stats/__init__.mojo
@@ -20,6 +20,8 @@ from .descriptive import (
     kurtosis,
     data_min,
     data_max,
+    gmean,
+    hmean,
 )
 
 from .correlation import (
diff --git a/src/stamojo/stats/descriptive.mojo b/src/stamojo/stats/descriptive.mojo
index 62e4e6b..7a84883 100644
--- a/src/stamojo/stats/descriptive.mojo
+++ b/src/stamojo/stats/descriptive.mojo
@@ -15,9 +15,11 @@ Provides functions for computing summary statistics of ``List[Float64]`` data:
 - ``kurtosis`` — (Excess) kurtosis (bias-corrected)
 - ``data_min`` — Minimum value
 - ``data_max`` — Maximum value
+- ``gmean`` — Geometric mean
+- ``hmean`` — Harmonic mean
 """
 
-from math import sqrt, nan
+from math import sqrt, nan, log, exp
 
 
 # ===----------------------------------------------------------------------=== #
@@ -259,3 +261,119 @@ fn data_max(data: List[Float64]) -> Float64:
         if data[i] > result:
             result = data[i]
     return result
+
+
+# TODO: Due to limitation in mojo compiler in 0.26.1 (resolved in nightly), we can't have Optional[List].
+# Once we have that, we can make weights optional and handle the unweighted case more cleanly.
+# For now, we can just require an empty list for unweighted case.
+fn gmean(data: List[Float64], weights: List[Float64]) -> Float64:
+    """Compute the (optionally weighted) geometric mean of a list of values.
+
+    The geometric mean is the nth root of the product of n values. If weights are provided,
+    computes the weighted geometric mean using the formula:
+
+        exp(Σ(wᵢ · log(xᵢ)) / Σwᵢ)
+
+    where xᵢ are the data values and wᵢ are the corresponding weights.
+    Terms with a zero weight are left out of the sum, so a zero value paired
+    with a zero weight does not turn the result into NaN.
+
+    Args:
+        data: A list of values, which must all be non-negative.
+        weights: A list of weights corresponding to each value in `data`.
+            If weights are provided, they must be the same length as `data`,
+            all weights must be non-negative, and the sum of weights must be greater than zero.
+            If an empty list is provided for weights, all values in `data` are treated as equally weighted.
+
+    Returns:
+        The weighted geometric mean of the values.
+        If `data` is empty, or if weights are provided and have a different length than `data`,
+        or if any weight is negative, or if the total weight is zero,
+        or if any value in `data` is negative, the function returns NaN.
+    """
+    var n = len(data)
+    if n == 0:
+        return nan[DType.float64]()
+
+    var ws = len(weights) > 0
+    if ws:
+        if len(weights) != n:
+            return nan[DType.float64]()
+        var total: Float64 = 0.0
+        for w in weights:
+            if w < 0.0:
+                return nan[DType.float64]()
+            total += w
+
+        if total == 0.0:
+            return nan[DType.float64]()
+        var log_sum: Float64 = 0.0
+        for i in range(n):
+            if data[i] < 0.0:
+                return nan[DType.float64]()
+            # A zero weight must contribute nothing. Without this guard a
+            # zero value would give 0 * log(0) = 0 * -inf = NaN.
+            if weights[i] > 0.0:
+                log_sum += weights[i] * log(data[i])
+        return exp(log_sum / total)
+
+    var log_sum: Float64 = 0.0
+    for i in range(n):
+        if data[i] < 0.0:
+            return nan[DType.float64]()
+        log_sum += log(data[i])
+    return exp(log_sum / Float64(n))
+
+
+fn hmean(data: List[Float64], weights: List[Float64]) -> Float64:
+    """Compute the (optionally weighted) harmonic mean of a list of values.
+
+    The harmonic mean is defined as n / (Σ(1/xᵢ)) for n values. If weights are provided,
+    computes the weighted harmonic mean using the formula:
+
+        Σwᵢ / Σ(wᵢ / xᵢ)
+
+    where xᵢ are the data values and wᵢ are the corresponding weights.
+
+    Args:
+        data: A list of positive values (must all be strictly greater than zero).
+        weights: A list of weights corresponding to each value in `data`.
+            If weights are provided, they must be the same length as `data`,
+            all weights must be non-negative, and the sum of weights must be greater than zero.
+            If an empty list is provided for weights, all values in `data` are treated as equally weighted.
+
+    Returns:
+        The weighted harmonic mean of the values.
+        If `data` is empty, or if weights are provided and have a different length than `data`,
+        or if any weight is negative, or if the total weight is zero,
+        or if any value in `data` is zero or negative, the function returns NaN.
+    """
+    var n = len(data)
+    if n == 0:
+        return nan[DType.float64]()
+
+    var ws = len(weights) > 0
+    if ws:
+        if len(weights) != n:
+            return nan[DType.float64]()
+        var total: Float64 = 0.0
+        for w in weights:
+            if w < 0.0:
+                return nan[DType.float64]()
+            total += w
+
+        if total == 0.0:
+            return nan[DType.float64]()
+        var inv_sum: Float64 = 0.0
+        for i in range(n):
+            if data[i] <= 0.0:
+                return nan[DType.float64]()
+            inv_sum += weights[i] / data[i]
+        return total / inv_sum
+
+    var inv_sum: Float64 = 0.0
+    for i in range(n):
+        if data[i] <= 0.0:
+            return nan[DType.float64]()
+        inv_sum += 1.0 / data[i]
+    return Float64(n) / inv_sum
diff --git a/tests/test_stats.mojo b/tests/test_stats.mojo
index d5df865..96fb632 100644
--- a/tests/test_stats.mojo
+++ b/tests/test_stats.mojo
@@ -8,7 +8,7 @@ Covers mean, variance, std, median, quantile, skewness, and kurtosis with
 both analytical checks and scipy/numpy comparisons.
 """
 
-from math import sqrt
+from math import sqrt, exp, log
 from python import Python, PythonObject
 from testing import assert_almost_equal, TestSuite
 
@@ -22,6 +22,8 @@ from stamojo.stats import (
     kurtosis,
     data_min,
     data_max,
+    gmean,
+    hmean,
 )
 
 
@@ -152,6 +154,92 @@ fn test_scipy_comparison() raises:
         print("⊘ test_scipy_comparison skipped (numpy not available)")
 
 
+fn test_gmean() raises:
+    """Test geometric mean."""
+    # first three test values are from scipy examples.
+    var data: List[Float64] = [1.0, 4.0]
+    var res = gmean(data, List[Float64]())
+    assert_almost_equal(res, 2.0, atol=1e-12)
+
+    var data2: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var res2 = gmean(data2, List[Float64]())
+    assert_almost_equal(res2, 3.3800151591412964, atol=1e-12)
+
+    var data3: List[Float64] = [1.0, 4.0, 7.0]
+    var weights3: List[Float64] = [3.0, 1.0, 3.0]
+    var res3 = gmean(data3, weights3)
+    assert_almost_equal(res3, 2.80668351922014, atol=1e-12)
+
+    # A zero-weight entry contributes nothing, even when its value is 0.
+    var data_zw: List[Float64] = [0.0, 4.0]
+    var weights_zw: List[Float64] = [0.0, 1.0]
+    var res_zw = gmean(data_zw, weights_zw)
+    assert_almost_equal(res_zw, 4.0, atol=1e-12)
+
+    # Only a failed import counts as "scipy not available"; the comparisons
+    # run outside the try so a failing assertion is not reported as a skip.
+    var sp: PythonObject
+    try:
+        sp = Python.import_module("scipy.stats")
+    except:
+        print("⊘ test_gmean scipy comparison skipped (scipy not available)")
+        return
+
+    var data4: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var py_data4 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
+    var sp_gmean = _py_f64(sp.gmean(py_data4))
+    var res4 = gmean(data4, List[Float64]())
+    assert_almost_equal(res4, sp_gmean, atol=1e-12)
+
+    var data5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var weights5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var py_data5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
+    var py_weights5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
+    var sp_gmean_w = _py_f64(sp.gmean(a=py_data5, weights=py_weights5))
+    var res5 = gmean(data5, weights5)
+    assert_almost_equal(res5, sp_gmean_w, atol=1e-12)
+
+
+fn test_hmean() raises:
+    """Test harmonic mean."""
+    # first three test values are from scipy examples.
+    var data: List[Float64] = [1.0, 4.0]
+    var res = hmean(data, List[Float64]())
+    assert_almost_equal(res, 1.6, atol=1e-12)
+
+    var data2: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var res2 = hmean(data2, List[Float64]())
+    assert_almost_equal(res2, 2.6997245179063363, atol=1e-12)
+
+    var data3: List[Float64] = [1.0, 4.0, 7.0]
+    var weights3: List[Float64] = [3.0, 1.0, 3.0]
+    var res3 = hmean(data3, weights3)
+    assert_almost_equal(res3, 1.9029126213592233, atol=1e-12)
+
+    # Only a failed import counts as "scipy not available"; the comparisons
+    # run outside the try so a failing assertion is not reported as a skip.
+    var sp: PythonObject
+    try:
+        sp = Python.import_module("scipy.stats")
+    except:
+        print("⊘ test_hmean scipy comparison skipped (scipy not available)")
+        return
+
+    var data4: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var py_data4 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
+    var sp_hmean = _py_f64(sp.hmean(py_data4))
+    var res4 = hmean(data4, List[Float64]())
+    assert_almost_equal(res4, sp_hmean, atol=1e-12)
+
+    var data5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var weights5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+    var py_data5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
+    var py_weights5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
+    var sp_hmean_w = _py_f64(sp.hmean(a=py_data5, weights=py_weights5))
+    var res5 = hmean(data5, weights5)
+    assert_almost_equal(res5, sp_hmean_w, atol=1e-12)
+
+
 # ===----------------------------------------------------------------------=== #
 # Main test runner
 # ===----------------------------------------------------------------------=== #