Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/stamojo/stats/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ from .descriptive import (
kurtosis,
data_min,
data_max,
gmean,
hmean,
)

from .correlation import (
Expand Down
116 changes: 115 additions & 1 deletion src/stamojo/stats/descriptive.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ Provides functions for computing summary statistics of ``List[Float64]`` data:
- ``kurtosis`` — (Excess) kurtosis (bias-corrected)
- ``data_min`` — Minimum value
- ``data_max`` — Maximum value
- ``gmean`` — Geometric mean
- ``hmean`` — Harmonic mean
"""

from math import sqrt, nan
from math import sqrt, nan, log, exp


# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -259,3 +261,115 @@ fn data_max(data: List[Float64]) -> Float64:
if data[i] > result:
result = data[i]
return result


# TODO: Due to a limitation in the Mojo compiler in 0.26.1 (resolved in nightly), we can't use Optional[List].
# Once that is available, make `weights` optional to handle the unweighted case more cleanly.
# For now, callers must pass an empty list to request the unweighted case.
fn gmean(data: List[Float64], weights: List[Float64]) -> Float64:
    """Compute the (optionally weighted) geometric mean of a list of values.

    The geometric mean is the nth root of the product of n values. If weights
    are provided, computes the weighted geometric mean using the formula:

        exp(Σ(wᵢ · log(xᵢ)) / Σwᵢ)

    where xᵢ are the data values and wᵢ are the corresponding weights.

    Entries whose weight is exactly zero contribute nothing to either sum
    (xᵢ⁰ = 1), so they are skipped. This keeps a zero or infinite value that
    carries zero weight from turning the result into NaN via ``0 * log(xᵢ)``.

    Args:
        data: A list of values, which must all be non-negative.
        weights: A list of weights corresponding to each value in `data`.
            If weights are provided, they must be the same length as `data`,
            all weights must be non-negative, and the sum of weights must be
            greater than zero. If an empty list is provided for weights, all
            values in `data` are treated as equally weighted.

    Returns:
        The weighted geometric mean of the values.
        If `data` is empty, or if weights are provided and have a different
        length than `data`, or if any weight is negative, or if the total
        weight is zero, or if any value in `data` is negative (including
        values whose weight is zero), the function returns NaN.
    """
    var n = len(data)
    if n == 0:
        return nan[DType.float64]()

    var weighted = len(weights) > 0
    if weighted and len(weights) != n:
        return nan[DType.float64]()

    var total: Float64 = 0.0
    var log_sum: Float64 = 0.0
    for i in range(n):
        var x = data[i]
        if x < 0.0:
            return nan[DType.float64]()

        # Unweighted input is treated as weight 1.0 for every element;
        # 1.0 * log(x) is exact, so this matches a plain sum of logs.
        var w: Float64 = 1.0
        if weighted:
            w = weights[i]
            if w < 0.0:
                return nan[DType.float64]()

        # A zero-weight term is excluded outright: 0 * log(0) and
        # 0 * log(inf) would otherwise evaluate to NaN and poison the sum.
        # NOTE(review): confirm whether exact parity with scipy is wanted
        # for this edge case.
        if w == 0.0:
            continue

        total += w
        log_sum += w * log(x)

    # Every weight was zero, so the weighted mean is undefined.
    if total == 0.0:
        return nan[DType.float64]()
    return exp(log_sum / total)


fn hmean(data: List[Float64], weights: List[Float64]) -> Float64:
    """Compute the (optionally weighted) harmonic mean of a list of values.

    For n values the harmonic mean is n / Σ(1/xᵢ). When weights are supplied
    the weighted form is used instead:

        Σwᵢ / Σ(wᵢ / xᵢ)

    where xᵢ are the data values and wᵢ are the corresponding weights.

    Args:
        data: A list of positive values (must all be strictly greater than zero).
        weights: A list of weights corresponding to each value in `data`.
            When non-empty it must have the same length as `data`, contain no
            negative weight, and sum to more than zero. An empty list means
            every value in `data` is weighted equally.

    Returns:
        The weighted harmonic mean of the values.
        NaN is returned if `data` is empty, if a non-empty `weights` differs
        in length from `data`, if any weight is negative, if the total weight
        is zero, or if any value in `data` is zero or negative.
    """
    var count = len(data)
    if count == 0:
        return nan[DType.float64]()

    # Unweighted case: an empty weight list requests equal weighting.
    if len(weights) == 0:
        var reciprocal_sum: Float64 = 0.0
        for k in range(count):
            var x = data[k]
            if x <= 0.0:
                return nan[DType.float64]()
            reciprocal_sum += 1.0 / x
        return Float64(count) / reciprocal_sum

    if len(weights) != count:
        return nan[DType.float64]()

    # Weighted case: validate and accumulate in one pass. Every rejection
    # yields NaN, so the order in which the checks fire does not matter.
    var weight_sum: Float64 = 0.0
    var weighted_reciprocals: Float64 = 0.0
    for k in range(count):
        var w = weights[k]
        var x = data[k]
        if w < 0.0:
            return nan[DType.float64]()
        if x <= 0.0:
            return nan[DType.float64]()
        weight_sum += w
        weighted_reciprocals += w / x

    # All weights zero: the weighted mean is undefined.
    if weight_sum == 0.0:
        return nan[DType.float64]()
    return weight_sum / weighted_reciprocals
74 changes: 73 additions & 1 deletion tests/test_stats.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Covers mean, variance, std, median, quantile, skewness, and kurtosis
with both analytical checks and scipy/numpy comparisons.
"""

from math import sqrt
from math import sqrt, exp, log
from python import Python, PythonObject
from testing import assert_almost_equal, TestSuite

Expand All @@ -22,6 +22,8 @@ from stamojo.stats import (
kurtosis,
data_min,
data_max,
gmean,
hmean,
)


Expand Down Expand Up @@ -152,6 +154,76 @@ fn test_scipy_comparison() raises:
print("⊘ test_scipy_comparison skipped (numpy not available)")


fn test_gmean() raises:
    """Test geometric mean against fixed reference values and, if available, scipy.

    The first three expected values are taken from the scipy examples. The
    scipy comparison is skipped (with a printed notice) when the Python side
    cannot be evaluated; a numerical mismatch against scipy fails the test.
    """
    # first three test values are from scipy examples.
    var data: List[Float64] = [1.0, 4.0]
    var res = gmean(data, List[Float64]())
    assert_almost_equal(res, 2.0, atol=1e-12)

    var data2: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var res2 = gmean(data2, List[Float64]())
    assert_almost_equal(res2, 3.3800151591412964, atol=1e-12)

    var data3: List[Float64] = [1.0, 4.0, 7.0]
    var weights3: List[Float64] = [3.0, 1.0, 3.0]
    var res3 = gmean(data3, weights3)
    assert_almost_equal(res3, 2.80668351922014, atol=1e-12)

    # Only the Python interop is guarded by `try`. The assertions run after
    # it so that a real mismatch raises instead of being caught by the bare
    # `except` and misreported as "skipped".
    var sp_gmean: Float64 = 0.0
    var sp_gmean_w: Float64 = 0.0
    try:
        var sp = Python.import_module("scipy.stats")
        var py_data4 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        sp_gmean = _py_f64(sp.gmean(py_data4))

        var py_data5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        var py_weights5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        sp_gmean_w = _py_f64(sp.gmean(a=py_data5, weights=py_weights5))
    except:
        print("⊘ test_gmean scipy comparison skipped (scipy not available)")
        return

    var data4: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var res4 = gmean(data4, List[Float64]())
    assert_almost_equal(res4, sp_gmean, atol=1e-12)

    var data5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var weights5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var res5 = gmean(data5, weights5)
    assert_almost_equal(res5, sp_gmean_w, atol=1e-12)


fn test_hmean() raises:
    """Test harmonic mean against fixed reference values and, if available, scipy.

    The first three expected values are taken from the scipy examples. The
    scipy comparison is skipped (with a printed notice) when the Python side
    cannot be evaluated; a numerical mismatch against scipy fails the test.
    """
    # first three test values are from scipy examples.
    var data: List[Float64] = [1.0, 4.0]
    var res = hmean(data, List[Float64]())
    assert_almost_equal(res, 1.6, atol=1e-12)

    var data2: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var res2 = hmean(data2, List[Float64]())
    assert_almost_equal(res2, 2.6997245179063363, atol=1e-12)

    var data3: List[Float64] = [1.0, 4.0, 7.0]
    var weights3: List[Float64] = [3.0, 1.0, 3.0]
    var res3 = hmean(data3, weights3)
    assert_almost_equal(res3, 1.9029126213592233, atol=1e-12)

    # Only the Python interop is guarded by `try`. The assertions run after
    # it so that a real mismatch raises instead of being caught by the bare
    # `except` and misreported as "skipped".
    var sp_hmean: Float64 = 0.0
    var sp_hmean_w: Float64 = 0.0
    try:
        var sp = Python.import_module("scipy.stats")
        var py_data4 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        sp_hmean = _py_f64(sp.hmean(py_data4))

        var py_data5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        var py_weights5 = Python.list(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        sp_hmean_w = _py_f64(sp.hmean(a=py_data5, weights=py_weights5))
    except:
        print("⊘ test_hmean scipy comparison skipped (scipy not available)")
        return

    var data4: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var res4 = hmean(data4, List[Float64]())
    assert_almost_equal(res4, sp_hmean, atol=1e-12)

    var data5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var weights5: List[Float64] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    var res5 = hmean(data5, weights5)
    assert_almost_equal(res5, sp_hmean_w, atol=1e-12)


# ===----------------------------------------------------------------------=== #
# Main test runner
# ===----------------------------------------------------------------------=== #
Expand Down