Skip to content

Commit 512746d

Browse files
[distribution][mojo][doc] Implement Beta, Gamma, Poisson distributions + Update codebase to Mojo v0.26.2 + Update docstrings (#5)
This PR does the following: - Updates the codebase of StaMojo to Mojo v0.26.2. - Adds the Beta, Gamma, and Poisson distributions. - Changes `std` to `stddev` in descriptive.mojo to avoid conflicts with Mojo's default `std` stdlib namespace. - Updates docstrings to comply with the Mojo docstring convention. --------- Co-authored-by: ZHU Yuhao <dr.yuhao.zhu@outlook.com>
1 parent 9668df9 commit 512746d

26 files changed

Lines changed: 2093 additions & 679 deletions

.github/workflows/run_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
os: ["ubuntu-22.04"]
19+
os: ["macos-latest"]
2020

2121
runs-on: ${{ matrix.os }}
2222
timeout-minutes: 30

pixi.lock

Lines changed: 518 additions & 416 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[workspace]
22
authors = ["MojoMath <https://github.com/mojomath>"]
3-
channels = ["https://repo.prefix.dev/modular-community", "https://conda.modular.com/max-nightly", "https://conda.modular.com/max", "conda-forge"]
3+
channels = ["https://repo.prefix.dev/modular-community", "https://conda.modular.com/max", "conda-forge", "https://conda.modular.com/max-nightly"]
44
description = "A statistical computing library for Mojo, inspired by scipy.stats and statsmodels"
55
license = "Apache-2.0"
66
name = "stamojo"
@@ -9,7 +9,7 @@ readme = "README.md"
99
version = "0.1.0"
1010

1111
[dependencies]
12-
mojo = "==0.26.1"
12+
mojo = ">=0.26.2.0,<0.27"
1313

1414
# ── Feature: test (adds scipy for reference-value benchmarks) ────────────────
1515
[feature.test.dependencies]
@@ -37,4 +37,4 @@ c = "clear && pixi run clean"
3737
clean = "rm -f tests/stamojo.mojopkg"
3838

3939
# doc
40-
doc = "clear && pixi run mojo doc -o docs/doc.json --diagnose-missing-doc-strings --validate-doc-strings src/stamojo"
40+
doc = "pixi run mojo doc --diagnose-missing-doc-strings src/stamojo > /dev/null"

src/stamojo/distributions/__init__.mojo

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ Distributions provided:
1414
- `FDist` — F-distribution (Fisher-Snedecor)
1515
- `Exponential` — Exponential distribution
1616
- `Binomial` — Binomial distribution
17+
- `Gamma` — Gamma distribution
18+
- `Beta` — Beta distribution
19+
- `Poisson` — Poisson distribution
1720
"""
1821

1922
from .normal import Normal
@@ -22,3 +25,6 @@ from .chi2 import ChiSquared
2225
from .f import FDist
2326
from .exponential import Exponential
2427
from .binomial import Binomial
28+
from .gamma import Gamma
29+
from .beta import Beta
30+
from .poisson import Poisson
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Stamojo - Distributions - Beta distribution
3+
# Licensed under Apache 2.0
4+
# ===----------------------------------------------------------------------=== #
5+
"""Beta distribution.
6+
7+
Provides the `Beta` distribution struct with PDF, log-PDF, CDF,
8+
survival function, and percent-point function (PPF / quantile).
9+
10+
The beta distribution with shape parameters *a* and *b* has PDF::
11+
12+
f(x; a, b) = x^{a-1} (1-x)^{b-1} / B(a, b), 0 < x < 1
13+
"""
14+
15+
from std.math import sqrt, log, lgamma, exp, nan, inf
16+
17+
from stamojo.special import betainc, lbeta, ndtri
18+
19+
20+
# ===----------------------------------------------------------------------=== #
21+
# Constants
22+
# ===----------------------------------------------------------------------=== #
23+
24+
# Convergence tolerance for the quantile search: `Beta.ppf` returns as soon as
# |CDF(x) - p| drops below this value.
comptime _EPS = 1.0e-12
25+
# Upper bound on the number of root-finding iterations performed by `Beta.ppf`
# before it gives up and returns its current estimate.
comptime _MAX_ITER = 100
26+
27+
28+
# ===----------------------------------------------------------------------=== #
29+
# Beta distribution
30+
# ===----------------------------------------------------------------------=== #
31+
32+
33+
@fieldwise_init
struct Beta(Copyable, Movable):
    """Beta distribution with shape parameters `a` and `b`.

    The density on the support [0, 1] is
    f(x; a, b) = x^(a-1) (1-x)^(b-1) / B(a, b).

    The shape parameters are not validated on construction; all methods
    assume `a > 0` and `b > 0`.

    Fields:
        a: First shape parameter. Must be positive.
        b: Second shape parameter. Must be positive.
    """

    var a: Float64
    """First shape parameter. Must be positive."""

    var b: Float64
    """Second shape parameter. Must be positive."""

    # --- Private helpers -----------------------------------------------------

    @staticmethod
    def _digamma(x: Float64) -> Float64:
        """Computes the digamma function ψ(x) for positive *x*.

        The argument is shifted upwards with the recurrence
        ψ(x) = ψ(x + 1) - 1/x until it is at least 6, where the asymptotic
        series ln(x) - 1/(2x) - 1/(12x²) + 1/(120x⁴) - ... is accurate to
        roughly 1e-11. Using the bare series for small arguments (as a
        three-term truncation would) is off by more than 0.1 at x = 0.5.

        Args:
            x: Point at which to evaluate ψ. Must be positive.

        Returns:
            The value ψ(x), or NaN when *x* is not positive.
        """
        # Also rejects NaN, which fails every ordered comparison.
        if not (x > 0.0):
            return nan[DType.float64]()

        var shift: Float64 = 0.0
        var y = x
        while y < 6.0:
            shift -= 1.0 / y
            y += 1.0

        var inv = 1.0 / y
        var inv2 = inv * inv
        # Horner form of the Bernoulli-number tail of the asymptotic series.
        var tail = inv2 * (
            1.0 / 12.0
            - inv2
            * (
                1.0 / 120.0
                - inv2
                * (1.0 / 252.0 - inv2 * (1.0 / 240.0 - inv2 * (1.0 / 132.0)))
            )
        )
        return shift + log(y) - 0.5 * inv - tail

    @staticmethod
    def _endpoint_logpdf(near: Float64, far: Float64) -> Float64:
        """Computes the log-density exactly at an endpoint of the support.

        At x = 0 the factor x^(a-1) decides the limit, at x = 1 the factor
        (1-x)^(b-1) does; *near* is the shape parameter belonging to that
        factor and *far* is the other one.

        Args:
            near: Shape parameter of the factor that vanishes or diverges.
            far: The other shape parameter.

        Returns:
            +inf when `near < 1`, -inf when `near > 1`, and `log(far)` when
            `near == 1` (since 1 / B(1, far) = far).
        """
        if near < 1.0:
            return inf[DType.float64]()
        if near > 1.0:
            return -inf[DType.float64]()
        return log(far)

    # --- Density functions ---------------------------------------------------

    def pdf(self, x: Float64) -> Float64:
        """Computes the probability density function at *x*.

        Outside [0, 1] the density is 0. At the endpoints the limiting value
        is returned, which may be 0, finite, or +inf depending on the shape
        parameters.

        Args:
            x: Point at which to evaluate the PDF.

        Returns:
            The PDF value at *x*.
        """
        if x < 0.0 or x > 1.0:
            return 0.0
        # exp(-inf) == 0 and exp(+inf) == +inf, so endpoints need no branch.
        return exp(self.logpdf(x))

    def logpdf(self, x: Float64) -> Float64:
        """Computes the natural logarithm of the PDF at *x*.

        Args:
            x: Point at which to evaluate the log-PDF.

        Returns:
            The log-PDF value at *x*; -inf outside the support [0, 1].
        """
        if x < 0.0 or x > 1.0:
            return -inf[DType.float64]()
        # The endpoints are handled explicitly: evaluating the general
        # formula there would produce 0 * log(0) = NaN for a == 1 or b == 1.
        if x == 0.0:
            return Self._endpoint_logpdf(self.a, self.b)
        if x == 1.0:
            return Self._endpoint_logpdf(self.b, self.a)
        return (
            (self.a - 1.0) * log(x)
            + (self.b - 1.0) * log(1.0 - x)
            - lbeta(self.a, self.b)
        )

    # --- Distribution functions ----------------------------------------------

    def cdf(self, x: Float64) -> Float64:
        """Computes the cumulative distribution function P(X ≤ x).

        CDF(x; a, b) = I_x(a, b) (regularized incomplete beta).

        Args:
            x: Point at which to evaluate the CDF.

        Returns:
            The CDF value at *x*.
        """
        if x <= 0.0:
            return 0.0
        if x >= 1.0:
            return 1.0
        return betainc(self.a, self.b, x)

    def logcdf(self, x: Float64) -> Float64:
        """Computes the natural logarithm of the CDF at *x*.

        Args:
            x: Point at which to evaluate the log-CDF.

        Returns:
            The log-CDF value at *x*.
        """
        if x <= 0.0:
            return -inf[DType.float64]()
        if x >= 1.0:
            return 0.0
        var c = self.cdf(x)
        # Guards against a CDF that underflowed to zero (or slightly below).
        if c <= 0.0:
            return -inf[DType.float64]()
        return log(c)

    def sf(self, x: Float64) -> Float64:
        """Computes the survival function (1 − CDF) at *x*.

        Uses the reflection identity 1 − I_x(a, b) = I_{1−x}(b, a) rather
        than subtracting the CDF from one, so that small upper-tail
        probabilities are not lost to cancellation.

        Args:
            x: Point at which to evaluate the survival function.

        Returns:
            The survival function value at *x*.
        """
        if x <= 0.0:
            return 1.0
        if x >= 1.0:
            return 0.0
        return betainc(self.b, self.a, 1.0 - x)

    def logsf(self, x: Float64) -> Float64:
        """Computes the natural logarithm of the survival function at *x*.

        Args:
            x: Point at which to evaluate the log-SF.

        Returns:
            The log-SF value at *x*.
        """
        if x <= 0.0:
            return 0.0
        if x >= 1.0:
            return -inf[DType.float64]()
        var s = self.sf(x)
        # Guards against a survival value that underflowed to zero.
        if s <= 0.0:
            return -inf[DType.float64]()
        return log(s)

    def ppf(self, p: Float64) -> Float64:
        """Computes the percent-point function (quantile / inverse CDF).

        Uses Newton-Raphson safeguarded by a bisection bracket: every
        iterate tightens [lo, hi] around the root, and a Newton step that
        leaves the bracket (or is not a number) is replaced by the midpoint.
        Iteration stops when |CDF(x) − p| < `_EPS` or after `_MAX_ITER`
        steps, in which case the current estimate is returned.

        Args:
            p: Probability value in [0, 1].

        Returns:
            The quantile corresponding to *p*, or NaN when *p* is NaN or
            lies outside [0, 1].
        """
        # Written as a negated conjunction so that NaN is rejected as well.
        if not (p >= 0.0 and p <= 1.0):
            return nan[DType.float64]()
        if p == 0.0:
            return 0.0
        if p == 1.0:
            return 1.0

        # Starting point: the mean, refined by a normal approximation when
        # the density is unimodal with an interior mode (a > 1 and b > 1).
        var x = self.mean()
        if self.a > 1.0 and self.b > 1.0:
            x += sqrt(self.variance()) * ndtri(p)
            if x <= 0.0:
                x = 0.01
            elif x >= 1.0:
                x = 0.99

        var lo: Float64 = 0.0
        var hi: Float64 = 1.0

        for _ in range(_MAX_ITER):
            var f = self.cdf(x) - p
            if abs(f) < _EPS:
                return x

            # The CDF is non-decreasing, so the sign of the residual tells
            # which side of the root the current iterate is on.
            if f > 0.0:
                hi = x
            else:
                lo = x
            var midpoint = (lo + hi) / 2.0

            var slope = self.pdf(x)
            if slope > 1.0e-300:
                var candidate = x - f / slope
                # A NaN candidate fails both comparisons and falls back to
                # bisection instead of poisoning the iteration.
                if candidate > lo and candidate < hi:
                    x = candidate
                else:
                    x = midpoint
            else:
                # Density is numerically zero: Newton's step is undefined.
                x = midpoint

        return x

    def isf(self, q: Float64) -> Float64:
        """Computes the inverse survival function (inverse SF).

        Args:
            q: Probability in [0, 1].

        Returns:
            The value *x* such that SF(x) = *q*.
        """
        return self.ppf(1.0 - q)

    # --- Summary statistics --------------------------------------------------

    def median(self) -> Float64:
        """Computes the median of the distribution.

        The beta median has no general closed form, so it is obtained as the
        0.5 quantile. For a symmetric distribution (`a == b`) the median is
        exactly 0.5 and is returned directly.

        Returns:
            The median of the distribution.
        """
        if self.a == self.b:
            return 0.5
        return self.ppf(0.5)

    def mean(self) -> Float64:
        """Computes the distribution mean = a / (a + b).

        Returns:
            The mean of the distribution.
        """
        return self.a / (self.a + self.b)

    def variance(self) -> Float64:
        """Computes the distribution variance = ab / ((a+b)²(a+b+1)).

        Returns:
            The variance of the distribution.
        """
        var ab = self.a + self.b
        return self.a * self.b / (ab * ab * (ab + 1.0))

    def std(self) -> Float64:
        """Computes the distribution standard deviation.

        Returns:
            The standard deviation of the distribution.
        """
        return sqrt(self.variance())

    def entropy(self) -> Float64:
        """Computes the differential entropy of the distribution.

        H = ln(B(a,b)) - (a-1)ψ(a) - (b-1)ψ(b) + (a+b-2)ψ(a+b)

        where ψ is the digamma function.

        Returns:
            The differential entropy.
        """
        var ab = self.a + self.b
        return (
            lbeta(self.a, self.b)
            - (self.a - 1.0) * Self._digamma(self.a)
            - (self.b - 1.0) * Self._digamma(self.b)
            + (ab - 2.0) * Self._digamma(ab)
        )

0 commit comments

Comments
 (0)