From 9e80e0b8fbee7e30bb06ce7fe73a1199c55c316b Mon Sep 17 00:00:00 2001 From: shivasankar Date: Thu, 5 Mar 2026 12:32:06 +0900 Subject: [PATCH 1/8] add DiscretelyDistributed trait --- src/stamojo/distributions/traits.mojo | 57 +++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/stamojo/distributions/traits.mojo b/src/stamojo/distributions/traits.mojo index a68aabe..d6505cc 100644 --- a/src/stamojo/distributions/traits.mojo +++ b/src/stamojo/distributions/traits.mojo @@ -61,3 +61,60 @@ trait ContinuouslyDistributed(Copyable, Movable): fn std(self) -> Float64: """Standard deviation of the distribution.""" ... + +trait DiscretelyDistributed(Copyable, Movable): + """Trait for discrete probability distributions.""" + + # --- Probability mass functions ------------------------------------------ + + fn pmf(self, k: Int64) -> Float64: + """Probability mass function at *k*.""" + ... + + fn logpmf(self, k: Int64) -> Float64: + """Natural logarithm of the probability mass function at *k*.""" + ... + + # --- Distribution functions ---------------------------------------------- + + fn cdf(self, k: Int64) -> Float64: + """Cumulative distribution function P(X ≤ k).""" + ... + + fn logcdf(self, k: Int64) -> Float64: + """Natural logarithm of the cumulative distribution function at *k*.""" + ... + + fn sf(self, k: Int64) -> Float64: + """Survival function (1 − CDF) at *k*.""" + ... + + fn logsf(self, k: Int64) -> Float64: + """Natural logarithm of the survival function at *k*.""" + ... + + fn ppf(self, q: Float64) -> Int64: + """Percent point function (inverse of CDF) at *q*.""" + ... + + fn isf(self, q: Float64) -> Int64: + """Inverse survival function (inverse of SF) at *q*.""" + ... + + # --- Statistical properties ---------------------------------------------- + + fn median(self) -> Int64: + """Median of the distribution.""" + ... + + fn mean(self) -> Float64: + """Mean of the distribution.""" + ... + + fn variance(self) -> Float64: + """Variance of the distribution.""" + ... + + fn std(self) -> Float64: + """Standard deviation of the distribution.""" + ... \ No newline at end of file From 1e4c7d90dfdd8eea2f5abecfff70ab95ccea3f86 Mon Sep 17 00:00:00 2001 From: shivasankar Date: Thu, 5 Mar 2026 15:14:34 +0900 Subject: [PATCH 2/8] Create binomial.mojo --- src/stamojo/distributions/binomial.mojo | 249 ++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 src/stamojo/distributions/binomial.mojo diff --git a/src/stamojo/distributions/binomial.mojo b/src/stamojo/distributions/binomial.mojo new file mode 100644 index 0000000..115be72 --- /dev/null +++ b/src/stamojo/distributions/binomial.mojo @@ -0,0 +1,249 @@ +# ===----------------------------------------------------------------------=== # +# Stamojo - Distributions - Binomial distribution +# Licensed under Apache 2.0 +# ===----------------------------------------------------------------------=== # +"""Binomial distribution. + +Provides the `Binomial` distribution struct with PMF, log-PMF, CDF, +survival function, and percent-point function (PPF / quantile). + +The binomial distribution with parameters n and p has PMF: + + P(X = k) = C(n, k) * p^k * (1-p)^(n-k), k = 0, 1, ..., n + +where C(n, k) is the binomial coefficient. +""" + +from math import log, log1p, exp, lgamma, nan, inf, floor, sqrt + +from stamojo.distributions.traits import DiscretelyDistributed + + +# `DiscretelyDistributed` trait contains `Copyable` and `Movable` traits +struct Binomial(DiscretelyDistributed): + """Binomial distribution. + + Represents the binomial distribution, a discrete probability distribution + that models the number of successes in a fixed number of independent + Bernoulli trials, each with the same probability of success. + + The probability mass function (PMF) for the binomial distribution is: + + P(X = k) = C(n, k) * p^k * (1-p)^(n-k) + + where C(n, k) is the binomial coefficient. + """ + + var n: UInt + """Number of trials (must be >= 0).""" + + var p: Float64 + """Probability of success in each trial (must be in [0, 1]).""" + + # --- Initialization ------------------------------------------------------- + + fn __init__(out self, n: UInt, p: Float64) raises: + self.n = n + self.p = p + + # --- Probability functions ------------------------------------------------ + + fn pmf(self, k: Int) -> Float64: + """Probability mass function at *k*. + + Args: + k: Number of successes (must be in [0, n]). + + Returns: + Returns 0.0 if k > n. + """ + return exp(self.logpmf(k)) + + fn logpmf(self, k: Int) -> Float64: + """Natural logarithm of the PMF at *k*. + + Args: + k: Number of successes (must be in [0, n]). + + Returns: + Returns -∞ if k > n. + """ + if k > Int(self.n): + return -inf[DType.float64]() + + if self.p == 0.0: + return 0.0 if k == 0 else -inf[DType.float64]() + + if self.p == 1.0: + return 0.0 if k == Int(self.n) else -inf[DType.float64]() + + var kf = Float64(k) + var nf = Float64(self.n) + var logc = _log_binomial_coefficient(self.n, k) + return logc + kf * log(self.p) + (nf - kf) * log1p(-self.p) + + fn cdf(self, k: Int) -> Float64: + """Cumulative distribution function P(X ≤ k). + + Args: + k: Number of successes (must be in [0, n]). + + Returns: + CDF value at *k*. Returns 1.0 for k >= n. + """ + if k >= Int(self.n): + return 1.0 + + if self.p == 0.0: + return 1.0 + if self.p == 1.0: + return 0.0 if k < Int(self.n) else 1.0 + + var nf = Float64(self.n) + var q = 1.0 - self.p + + var pmf_k = exp(nf * log(q)) # k = 0 + var total = pmf_k + + for i in range(0, k): + pmf_k *= (nf - Float64(i)) / Float64(i + 1) * (self.p / q) + total += pmf_k + + return total + + fn logcdf(self, k: Int) -> Float64: + """Natural logarithm of the CDF at *k*. + + Args: + k: Number of successes (must be in [0, n]). + + Returns: + Log-CDF value at *k*. Returns 0.0 for k >= n. + """ + var c = self.cdf(k) + if c <= 0.0: + return -inf[DType.float64]() + return log(c) + + fn sf(self, k: Int) -> Float64: + """Survival function (1 − CDF) at *k*. + + Args: + k: Number of successes (must be in [0, n]). + + Returns: + Survival function value at *k*. Returns 0.0 for k >= n. + """ + return 1.0 - self.cdf(k) + + fn logsf(self, k: Int) -> Float64: + """Natural logarithm of the survival function at *k*. + + Args: + k: Number of successes (must be in [0, n]). + + Returns: + Log-survival function value at *k*. + """ + if self.p < 0.0 or self.p > 1.0: + return nan[DType.float64]() + return log1p(-self.cdf(k)) + + fn ppf(self, q: Float64) -> Int: + """Percent point function (inverse CDF). + + Args: + q: Probability in [0, 1]. + + Returns: + Smallest integer k such that CDF(k) ≥ q. Returns 0 for q=0, n for q=1. + """ + if q == 0.0: + return 0 + if q == 1.0: + return Int(self.n) + + var cumulative: Float64 = 0.0 + for k in range(self.n + 1): + cumulative += self.pmf(Int(k)) + if cumulative >= q: + return Int(k) + + return Int(self.n) + + fn isf(self, q: Float64) -> Int: + """Inverse survival function (inverse SF). + + Args: + q: Probability in [0, 1]. + + Returns: + Smallest integer k such that SF(k) ≤ q. Returns n for q=0, 0 for q=1. + """ + if q == 0.0: + return Int(self.n) + if q == 1.0: + return 0 + + var cumulative = 0.0 + for k in range(self.n + 1): + cumulative += self.pmf(Int(k)) + if 1.0 - cumulative <= q: + return Int(k) + + return Int(self.n) + + # --- Summary statistics -------------------------------------------------- + fn median(self) -> UInt: + """Median of the distribution. + + Returns: + The median of the distribution. + """ + return UInt(floor(Float64(self.n) * self.p + 0.5)) + + fn mean(self) -> Float64: + """Distribution mean: n * p. + + Returns: + The mean of the distribution. + """ + return Float64(self.n) * self.p + + fn variance(self) -> Float64: + """Distribution variance: n * p * (1 - p). + + Returns: + The variance of the distribution. + """ + var np = Float64(self.n) * self.p + return np * (1.0 - self.p) + + fn std(self) -> Float64: + """Distribution standard deviation. + + Returns: + The standard deviation of the distribution. + """ + return sqrt(self.variance()) + + +# ===----------------------------------------------------------------------=== # +# Helper functions +# ===----------------------------------------------------------------------=== # + + +fn _log_binomial_coefficient(n: UInt, k: Int) -> Float64: + """Log of the binomial coefficient C(n, k). + + Args: + n: Number of trials. + k: Number of successes. + + Returns: + log(C(n, k)). + """ + var nf = Float64(n) + var kf = Float64(k) + var fnk = Float64(n - k) + return lgamma(nf + 1.0) - lgamma(kf + 1.0) - lgamma(fnk + 1.0) From d4e7da321f58e9ccc54a8a6de56c5c3f758dff0a Mon Sep 17 00:00:00 2001 From: shivasankar Date: Thu, 5 Mar 2026 15:14:41 +0900 Subject: [PATCH 3/8] add tests for binomial --- tests/test_distributions.mojo | 100 ++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/tests/test_distributions.mojo b/tests/test_distributions.mojo index e0272b4..4e404fa 100644 --- a/tests/test_distributions.mojo +++ b/tests/test_distributions.mojo @@ -22,6 +22,7 @@ from stamojo.distributions import ( ChiSquared, FDist, Exponential, + Binomial, ) @@ -488,6 +489,105 @@ fn test_expon_scipy() raises: assert_almost_equal(e2.cdf(x), sp_cdf2, atol=1e-12) +# ===----------------------------------------------------------------------=== # +# Binomial distribution tests +# ===----------------------------------------------------------------------=== # + + +fn test_binomial_pmf_basic() raises: + """Test Binomial PMF at known values.""" + var b = Binomial(10, 0.5) + assert_almost_equal(b.pmf(0), 1.0 / 1024.0, atol=1e-12) + assert_almost_equal(b.pmf(5), 252.0 / 1024.0, atol=1e-12) + assert_almost_equal(b.pmf(10), 1.0 / 1024.0, atol=1e-12) + assert_almost_equal(b.pmf(11), 0.0, atol=1e-12) + + +fn test_binomial_cdf_sf() raises: + """Test Binomial CDF and SF at known values.""" + var b = Binomial(4, 0.5) + assert_almost_equal(b.cdf(2), 0.6875, atol=1e-12) + assert_almost_equal(b.sf(2), 0.3125, atol=1e-12) + assert_almost_equal(b.cdf(4), 1.0, atol=1e-12) + + +fn test_binomial_edge_p() raises: + """Test Binomial behavior for p=0 and p=1.""" + var b0 = Binomial(5, 0.0) + assert_almost_equal(b0.pmf(0), 1.0, atol=1e-12) + assert_almost_equal(b0.pmf(1), 0.0, atol=1e-12) + assert_almost_equal(b0.cdf(0), 1.0, atol=1e-12) + assert_almost_equal(b0.sf(0), 0.0, atol=1e-12) + + var b1 = Binomial(5, 1.0) + assert_almost_equal(b1.pmf(5), 1.0, atol=1e-12) + assert_almost_equal(b1.pmf(4), 0.0, atol=1e-12) + assert_almost_equal(b1.cdf(4), 0.0, atol=1e-12) + assert_almost_equal(b1.cdf(5), 1.0, atol=1e-12) + + +fn test_binomial_logpmf() raises: + """Test Binomial log-PMF consistency.""" + var b = Binomial(6, 0.3) + var k = 2 + assert_almost_equal(b.logpmf(k), log(b.pmf(k)), atol=1e-12) + + +fn test_binomial_symmetry_p_half() raises: + """Test Binomial symmetry for p=0.5.""" + var b = Binomial(10, 0.5) + for k in range(0, 11): + assert_almost_equal(b.pmf(k), b.pmf(10 - k), atol=1e-12) + + +fn test_binomial_ppf_isf_roundtrip() raises: + """Test Binomial PPF/ISF consistency with CDF/SF.""" + var b = Binomial(12, 0.4) + var qs: List[Float64] = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99] + + for i in range(len(qs)): + var q = qs[i] + var k = b.ppf(q) + if b.cdf(k) < q: + raise Error("binomial ppf: cdf(k) < q") + if k > 0 and b.cdf(k - 1) >= q: + raise Error("binomial ppf: cdf(k-1) >= q") + + var k2 = b.isf(q) + if b.sf(k2) > q: + raise Error("binomial isf: sf(k) > q") + if k2 > 0 and b.sf(k2 - 1) <= q: + raise Error("binomial isf: sf(k-1) <= q") + + +fn test_binomial_scipy() raises: + """Test Binomial distribution against scipy.stats.binom.""" + var sp = _load_scipy_stats() + if sp is None: + print("test_binomial_scipy skipped (scipy not available)") + return + + var n: UInt = 10 + var p = 0.3 + var b = Binomial(n, p) + var ks: List[Int] = [0, 1, 2, 5, 10] + + for i in range(len(ks)): + var k = ks[i] + var sp_pmf = _py_f64(sp.binom.pmf(k, n, p)) + var sp_cdf = _py_f64(sp.binom.cdf(k, n, p)) + var sp_sf = _py_f64(sp.binom.sf(k, n, p)) + assert_almost_equal(b.pmf(k), sp_pmf, atol=1e-10) + assert_almost_equal(b.cdf(k), sp_cdf, atol=1e-10) + assert_almost_equal(b.sf(k), sp_sf, atol=1e-10) + + var qs: List[Float64] = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99] + for i in range(len(qs)): + var q = qs[i] + var sp_ppf = _py_f64(sp.binom.ppf(q, n, p)) + assert_almost_equal(Float64(b.ppf(q)), sp_ppf, atol=1e-10) + + # ===----------------------------------------------------------------------=== # # Main test runner # ===----------------------------------------------------------------------=== # From d10cf1bf584a93778745441bd983351377c0a1c6 Mon Sep 17 00:00:00 2001 From: shivasankar Date: Thu, 5 Mar 2026 15:15:03 +0900 Subject: [PATCH 4/8] addd `DiscretelyDistributed` --- src/stamojo/distributions/traits.mojo | 53 ++++++++++++++------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/stamojo/distributions/traits.mojo b/src/stamojo/distributions/traits.mojo index d6505cc..892281f 100644 --- a/src/stamojo/distributions/traits.mojo +++ b/src/stamojo/distributions/traits.mojo @@ -62,59 +62,60 @@ trait ContinuouslyDistributed(Copyable, Movable): """Standard deviation of the distribution.""" ... + trait DiscretelyDistributed(Copyable, Movable): """Trait for discrete probability distributions.""" - + # --- Probability mass functions ------------------------------------------ - - fn pmf(self, k: Int64) -> Float64: + + fn pmf(self, k: Int) -> Float64: """Probability mass function at *k*.""" ... - - fn logpmf(self, k: Int64) -> Float64: + + fn logpmf(self, k: Int) -> Float64: """Natural logarithm of the probability mass function at *k*.""" ... - + # --- Distribution functions ---------------------------------------------- - - fn cdf(self, k: Int64) -> Float64: + + fn cdf(self, k: Int) -> Float64: """Cumulative distribution function P(X ≤ k).""" ... - - fn logcdf(self, k: Int64) -> Float64: + + fn logcdf(self, k: Int) -> Float64: """Natural logarithm of the cumulative distribution function at *k*.""" ... - - fn sf(self, k: Int64) -> Float64: + + fn sf(self, k: Int) -> Float64: """Survival function (1 − CDF) at *k*.""" ... - - fn logsf(self, k: Int64) -> Float64: + + fn logsf(self, k: Int) -> Float64: """Natural logarithm of the survival function at *k*.""" ... - - fn ppf(self, q: Float64) -> Int64: + + fn ppf(self, q: Float64) -> Int: """Percent point function (inverse of CDF) at *q*.""" ... - - fn isf(self, q: Float64) -> Int64: + + fn isf(self, q: Float64) -> Int: """Inverse survival function (inverse of SF) at *q*.""" ... - - # --- Statistical properties ---------------------------------------------- - - fn median(self) -> Int64: + + # --- Statistical properties ---------------------------------------------- + + fn median(self) -> UInt: """Median of the distribution.""" ... - + fn mean(self) -> Float64: """Mean of the distribution.""" ... - + fn variance(self) -> Float64: """Variance of the distribution.""" ... - + fn std(self) -> Float64: """Standard deviation of the distribution.""" - ... \ No newline at end of file + ... From 6d945c243a87bc8cc33a97d3d899e662f3ab5cc1 Mon Sep 17 00:00:00 2001 From: shivasankar Date: Thu, 5 Mar 2026 15:15:15 +0900 Subject: [PATCH 5/8] update imports and docstring --- src/stamojo/distributions/__init__.mojo | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/stamojo/distributions/__init__.mojo b/src/stamojo/distributions/__init__.mojo index 56148d5..82d1ee5 100644 --- a/src/stamojo/distributions/__init__.mojo +++ b/src/stamojo/distributions/__init__.mojo @@ -13,6 +13,7 @@ Distributions provided: - `ChiSquared` — Chi-squared distribution - `FDist` — F-distribution (Fisher-Snedecor) - `Exponential` — Exponential distribution +- `Binomial` — Binomial distribution """ from .normal import Normal @@ -20,3 +21,4 @@ from .t import StudentT from .chi2 import ChiSquared from .f import FDist from .exponential import Exponential +from .binomial import Binomial From eefddc14cefaa81712eed9c88e942e946a22a49b Mon Sep 17 00:00:00 2001 From: shivasankar Date: Thu, 5 Mar 2026 15:23:43 +0900 Subject: [PATCH 6/8] remove raises in init --- src/stamojo/distributions/binomial.mojo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stamojo/distributions/binomial.mojo b/src/stamojo/distributions/binomial.mojo index 115be72..e8d4650 100644 --- a/src/stamojo/distributions/binomial.mojo +++ b/src/stamojo/distributions/binomial.mojo @@ -42,7 +42,7 @@ struct Binomial(DiscretelyDistributed): # --- Initialization ------------------------------------------------------- - fn __init__(out self, n: UInt, p: Float64) raises: + fn __init__(out self, n: UInt, p: Float64): self.n = n self.p = p From e843cd38accbc077e57308b4402ba218f62bc2bf Mon Sep 17 00:00:00 2001 From: shivasankar Date: Thu, 5 Mar 2026 15:33:55 +0900 Subject: [PATCH 7/8] change UInt to Int --- src/stamojo/distributions/binomial.mojo | 26 ++++++++++++------------- tests/test_distributions.mojo | 2 +- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/stamojo/distributions/binomial.mojo b/src/stamojo/distributions/binomial.mojo index e8d4650..0664582 100644 --- a/src/stamojo/distributions/binomial.mojo +++ b/src/stamojo/distributions/binomial.mojo @@ -34,7 +34,7 @@ struct Binomial(DiscretelyDistributed): where C(n, k) is the binomial coefficient. """ - var n: UInt + var n: Int """Number of trials (must be >= 0).""" var p: Float64 @@ -42,7 +42,7 @@ struct Binomial(DiscretelyDistributed): # --- Initialization ------------------------------------------------------- - fn __init__(out self, n: UInt, p: Float64): + fn __init__(out self, n: Int, p: Float64): self.n = n self.p = p @@ -68,14 +68,14 @@ struct Binomial(DiscretelyDistributed): Returns: Returns -∞ if k > n. """ - if k > Int(self.n): + if k > self.n: return -inf[DType.float64]() if self.p == 0.0: return 0.0 if k == 0 else -inf[DType.float64]() if self.p == 1.0: - return 0.0 if k == Int(self.n) else -inf[DType.float64]() + return 0.0 if k == self.n else -inf[DType.float64]() var kf = Float64(k) var nf = Float64(self.n) @@ -91,13 +91,13 @@ struct Binomial(DiscretelyDistributed): Returns: CDF value at *k*. Returns 1.0 for k >= n. """ - if k >= Int(self.n): + if k >= self.n: return 1.0 if self.p == 0.0: return 1.0 if self.p == 1.0: - return 0.0 if k < Int(self.n) else 1.0 + return 0.0 if k < self.n else 1.0 var nf = Float64(self.n) var q = 1.0 - self.p @@ -145,8 +145,6 @@ struct Binomial(DiscretelyDistributed): Returns: Log-survival function value at *k*. """ - if self.p < 0.0 or self.p > 1.0: - return nan[DType.float64]() return log1p(-self.cdf(k)) fn ppf(self, q: Float64) -> Int: @@ -161,7 +159,7 @@ struct Binomial(DiscretelyDistributed): if q == 0.0: return 0 if q == 1.0: - return Int(self.n) + return self.n var cumulative: Float64 = 0.0 for k in range(self.n + 1): @@ -169,7 +167,7 @@ struct Binomial(DiscretelyDistributed): if cumulative >= q: return Int(k) - return Int(self.n) + return self.n fn isf(self, q: Float64) -> Int: """Inverse survival function (inverse SF). @@ -181,7 +179,7 @@ struct Binomial(DiscretelyDistributed): Smallest integer k such that SF(k) ≤ q. Returns n for q=0, 0 for q=1. """ if q == 0.0: - return Int(self.n) + return self.n if q == 1.0: return 0 @@ -191,11 +189,11 @@ struct Binomial(DiscretelyDistributed): if 1.0 - cumulative <= q: return Int(k) - return Int(self.n) + return self.n # --- Summary statistics -------------------------------------------------- fn median(self) -> UInt: - """Median of the distribution. + """Median of the distribution: floor(n * p + 0.5). Returns: The median of the distribution. @@ -220,7 +218,7 @@ struct Binomial(DiscretelyDistributed): return np * (1.0 - self.p) fn std(self) -> Float64: - """Distribution standard deviation. + """Distribution standard deviation: sqrt(n * p * (1 - p)). Returns: The standard deviation of the distribution. diff --git a/tests/test_distributions.mojo b/tests/test_distributions.mojo index 4e404fa..dd00eb4 100644 --- a/tests/test_distributions.mojo +++ b/tests/test_distributions.mojo @@ -567,7 +567,7 @@ fn test_binomial_scipy() raises: print("test_binomial_scipy skipped (scipy not available)") return - var n: UInt = 10 + var n: Int = 10 var p = 0.3 var b = Binomial(n, p) var ks: List[Int] = [0, 1, 2, 5, 10] From 334493201365955c3eda22b828ee5288937bf7c5 Mon Sep 17 00:00:00 2001 From: ZHU Yuhao Date: Sat, 7 Mar 2026 10:58:38 +0100 Subject: [PATCH 8/8] Add checks when k < 0 for logpmf and cdf --- src/stamojo/distributions/binomial.mojo | 14 ++++++++------ tests/test_distributions.mojo | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/stamojo/distributions/binomial.mojo b/src/stamojo/distributions/binomial.mojo index 0664582..66122fc 100644 --- a/src/stamojo/distributions/binomial.mojo +++ b/src/stamojo/distributions/binomial.mojo @@ -68,7 +68,7 @@ struct Binomial(DiscretelyDistributed): Returns: Returns -∞ if k > n. """ - if k > self.n: + if k < 0 or k > self.n: return -inf[DType.float64]() if self.p == 0.0: @@ -91,6 +91,8 @@ struct Binomial(DiscretelyDistributed): Returns: CDF value at *k*. Returns 1.0 for k >= n. """ + if k < 0: + return 0.0 if k >= self.n: return 1.0 @@ -163,9 +165,9 @@ struct Binomial(DiscretelyDistributed): var cumulative: Float64 = 0.0 for k in range(self.n + 1): - cumulative += self.pmf(Int(k)) + cumulative += self.pmf(k) if cumulative >= q: - return Int(k) + return k return self.n @@ -185,9 +187,9 @@ struct Binomial(DiscretelyDistributed): var cumulative = 0.0 for k in range(self.n + 1): - cumulative += self.pmf(Int(k)) + cumulative += self.pmf(k) if 1.0 - cumulative <= q: - return Int(k) + return k return self.n @@ -231,7 +233,7 @@ struct Binomial(DiscretelyDistributed): # ===----------------------------------------------------------------------=== # -fn _log_binomial_coefficient(n: UInt, k: Int) -> Float64: +fn _log_binomial_coefficient(n: Int, k: Int) -> Float64: """Log of the binomial coefficient C(n, k). Args: diff --git a/tests/test_distributions.mojo b/tests/test_distributions.mojo index dd00eb4..b0898b1 100644 --- a/tests/test_distributions.mojo +++ b/tests/test_distributions.mojo @@ -4,7 +4,7 @@ # ===----------------------------------------------------------------------=== # """Tests for the distributions subpackage. -Covers Normal, Student's t, Chi-squared, F, and Exponential distributions. +Covers Normal, Student's t, Chi-squared, F, Exponential, and Binomial distributions. Each distribution is tested for: - Known analytical values - CDF/PPF round-trip consistency