diff --git a/CHANGES.md b/CHANGES.md index 9e9f184..ce357b1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,6 +4,8 @@ * Uses prettier `tqdm` output that is now aware of Jupyter notebooks. * `bayes.update` now supports `lognorm` and `gamma` distributions in addition to `norm` and `beta`. * Allow `zero_inflated` to accept 0 or 1 as probabilities. +* **[New feature]** Added `die(sides, explode_on=None)` and `coin()` as distribution objects for dice and coins. These support all distribution operations like `~`, `@`, `+`, `-`, etc. The `die` distribution also supports "exploding dice" mechanics via the `explode_on` parameter. +* **[Breaking change]** Removed `roll_die` and `flip_coin` functions. Use `sq.die(sides) @ n` or `~sq.coin()` instead. ## v0.29 - latest release diff --git a/README.md b/README.md index 600f420..94f5563 100644 --- a/README.md +++ b/README.md @@ -176,21 +176,34 @@ a, b = sq.correlate((a, b), 0.5) # Correlate a and b with a correlation of 0.5 a, b = sq.correlate((a, b), [[1, 0.5], [0.5, 1]]) ``` -#### Example: Rolling a Die +#### Example: Rolling Dice and Flipping Coins -An example of how to use distributions to build tools: +Squigglepy has built-in support for dice and coins as distribution objects: ```Python import squigglepy as sq -def roll_die(sides, n=1): - return sq.discrete(list(range(1, sides + 1))) @ n if sides > 0 else None - -roll_die(sides=6, n=10) +# Roll a 6-sided die +sq.die(6) @ 10 # Roll 10 times # [2, 6, 5, 2, 6, 2, 3, 1, 5, 2] + +# Use the ~ operator for a single sample +~sq.die(6) # Roll once +# 4 + +# Flip a coin +~sq.coin() +# 'heads' + +sq.coin() @ 5 # Flip 5 times +# ['heads', 'tails', 'heads', 'heads', 'tails'] + +# Exploding dice (roll again when you get certain values) +sq.die(6, explode_on=6) @ 5 # D6 that explodes on 6 +# [3, 8, 2, 5, 1] # The 8 came from rolling 6 + 2 ``` -This is already included standard in the utils of this package. Use `sq.roll_die`. +Since `die` and `coin` are distribution objects, they work with all distribution operations: ### Bayesian inference @@ -405,10 +418,11 @@ from squigglepy.numbers import K, M, B, T from squigglepy import bayes def define_event(): - if sq.flip_coin() == 'heads': # Blue bag - return sq.roll_die(6) + if ~sq.coin() == 'heads': # Blue bag + return ~sq.die(6) else: # Red bag - return sq.discrete([4, 6, 10, 20]) >> sq.roll_die + sides = ~sq.discrete([4, 6, 10, 20]) + return ~sq.die(sides) bayes.bayesnet(define_event, diff --git a/squigglepy/__init__.py b/squigglepy/__init__.py index 9ac7350..bcf9445 100644 --- a/squigglepy/__init__.py +++ b/squigglepy/__init__.py @@ -4,4 +4,5 @@ from .utils import * # noqa ignore=F405 from .rng import * # noqa ignore=F405 from .correlation import * # noqa ignore=F405 +from .dice import * # noqa ignore=F405 from .version import __version__ # noqa ignore=F405 diff --git a/squigglepy/dice.py b/squigglepy/dice.py new file mode 100644 index 0000000..ee80bc0 --- /dev/null +++ b/squigglepy/dice.py @@ -0,0 +1,105 @@ +from .utils import is_dist +from .distributions import OperableDistribution + + +class Die(OperableDistribution): + """ + A distribution representing a die roll. + + Supports exploding dice mechanics where additional dice are rolled + when certain values are rolled. + """ + + def __init__(self, sides=None, explode_on=None): + super().__init__() + + if is_dist(sides) or callable(sides): + from .samplers import sample + + sides = sample(sides) + if sides is None: + raise ValueError("sides must be specified") + if not isinstance(sides, int): + raise ValueError("can only roll an integer number of sides") + if sides < 2: + raise ValueError("cannot roll less than a 2-sided die.") + if explode_on is not None: + if not isinstance(explode_on, list): + explode_on = [explode_on] + if len(explode_on) >= sides: + raise ValueError("cannot explode on every value") + for val in explode_on: + if not isinstance(val, int) or val < 1 or val > sides: + raise ValueError(f"explode_on values must be integers between 1 and {sides}") + self.sides = sides + self.explode_on = explode_on + + def __str__(self): + explode_out = ( + "" if self.explode_on is None else ", explodes on {}".format(str(self.explode_on)) + ) + out = " Die({}{})".format(self.sides, explode_out) + return out + + +def die(sides, explode_on=None): + """ + Create a distribution for a die. + + Parameters + ---------- + sides : int + The number of sides of the die that is rolled. + explode_on : list or int or None + An additional die will be rolled if the initial die rolls any of these values. + Implements "exploding dice" mechanics. The exploding continues recursively + until a non-exploding value is rolled. + + Returns + ------- + Die + A distribution that models a die roll, returning values from 1 to sides. + + Examples + -------- + >>> die(6) + Die(6) + >>> die(6, explode_on=6) # D6 that explodes on 6 + Die(6, explodes on [6]) + """ + return Die(sides=sides, explode_on=explode_on) + + +class Coin(OperableDistribution): + """ + A distribution representing a coin flip. + + Returns either "heads" or "tails" with equal probability. + """ + + def __init__(self): + super().__init__() + + def __str__(self): + out = " Coin" + return out + + +def coin(): + """ + Create a distribution for a coin flip. + + Returns + ------- + Coin + A distribution that models a coin flip, returning either "heads" or "tails" + with equal probability. + + Examples + -------- + >>> coin() + Coin + >>> ~coin() # Sample a coin flip + 'heads' # or 'tails' + """ + return Coin() diff --git a/squigglepy/samplers.py b/squigglepy/samplers.py index 8edc48e..953eeae 100644 --- a/squigglepy/samplers.py +++ b/squigglepy/samplers.py @@ -49,6 +49,8 @@ const, ) +from .dice import Coin, Die + _squigglepy_internal_sample_caches = {} @@ -643,6 +645,77 @@ def geometric_sample(p, samples=1): return _simplify(_get_rng().geometric(p, samples)) +def die_sample(sides, explode_on=None, samples=1): + """ + Sample a random number from a die roll. + + Parameters + ---------- + sides : int + The number of sides of the die. + explode_on : list or int or None + Values that trigger an additional die roll. Implements "exploding dice" mechanics. + samples : int + The number of samples to return. + + Returns + ------- + int or array + A random number (or array of numbers) from rolling the die. + + Examples + -------- + >>> set_seed(42) + >>> die_sample(6) + 5 + >>> die_sample(6, explode_on=6, samples=3) + array([4, 2, 8]) # The 8 came from rolling a 6 and then a 2 + """ + + def _single_roll(): + total = 0 + roll = int(_get_rng().integers(1, sides + 1)) + total += roll + if explode_on is not None: + explode_list = explode_on if isinstance(explode_on, list) else [explode_on] + while roll in explode_list: + roll = int(_get_rng().integers(1, sides + 1)) + total += roll + return total + + if samples == 1: + return _single_roll() + else: + return np.array([_single_roll() for _ in range(samples)]) + + +def coin_sample(samples=1): + """ + Sample a coin flip. + + Parameters + ---------- + samples : int + The number of samples to return. + + Returns + ------- + str or list + "heads" or "tails" for a single sample, or a list of results for multiple samples. + + Examples + -------- + >>> set_seed(42) + >>> coin_sample() + 'heads' + """ + if samples == 1: + return "heads" if _get_rng().integers(0, 2) == 1 else "tails" + else: + results = _get_rng().integers(0, 2, samples) + return ["heads" if r == 1 else "tails" for r in results] + + def _mixture_sample_for_large_n( values, weights=None, @@ -1133,6 +1206,12 @@ def run_dist(dist, pbar=None, tick=1): elif isinstance(dist, GeometricDistribution): samples = geometric_sample(p=dist.p, samples=n) + elif isinstance(dist, Die): + samples = die_sample(sides=dist.sides, explode_on=dist.explode_on, samples=n) + + elif isinstance(dist, Coin): + samples = coin_sample(samples=n) + elif isinstance(dist, ComplexDistribution): if dist.right is None: samples = dist.fn(sample(dist.left, n=n, verbose=verbose)) diff --git a/squigglepy/utils.py b/squigglepy/utils.py index 74b9cec..09e1db4 100644 --- a/squigglepy/utils.py +++ b/squigglepy/utils.py @@ -888,72 +888,6 @@ def doubling_time_to_growth_rate(doubling_time): return math.exp(math.log(2) / doubling_time) - 1 -def roll_die(sides, n=1): - """ - Roll a die. - - Parameters - ---------- - sides : int - The number of sides of the die that is rolled. - n : int - The number of dice to be rolled. - - Returns - ------- - int or list - Returns the value of each die roll. - - Examples - -------- - >>> set_seed(42) - >>> roll_die(6) - 5 - """ - if is_dist(sides) or callable(sides): - from .samplers import sample - - sides = sample(sides) - if not isinstance(n, int): - raise ValueError("can only roll an integer number of times") - elif sides < 2: - raise ValueError("cannot roll less than a 2-sided die.") - elif not isinstance(sides, int): - raise ValueError("can only roll an integer number of sides") - else: - from .samplers import sample - from .distributions import discrete - - return sample(discrete(list(range(1, sides + 1))), n=n) if sides > 0 else None - - -def flip_coin(n=1): - """ - Flip a coin. - - Parameters - ---------- - n : int - The number of coins to be flipped. - - Returns - ------- - str or list - Returns the value of each coin flip, as either "heads" or "tails" - - Examples - -------- - >>> set_seed(42) - >>> flip_coin() - 'heads' - """ - rolls = roll_die(2, n=n) - if isinstance(rolls, int): - rolls = [rolls] - flips = ["heads" if d == 2 else "tails" for d in rolls] - return flips[0] if len(flips) == 1 else flips - - def kelly( my_price, market_price, deference=0, bankroll=1, resolve_date=None, current=0, error=True ): diff --git a/tests/integration.py b/tests/integration.py index f94dd94..b240ef6 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -157,12 +157,12 @@ def monte_hall_event(): def coins_and_dice(): - flip = sq.flip_coin() + flip = ~sq.coin() if flip == "heads": dice_sides = 6 else: dice_sides = ~sq.discrete([4, 6, 10, 20]) - return sq.roll_die(dice_sides) + return ~sq.die(dice_sides) def model(): diff --git a/tests/test_dice.py b/tests/test_dice.py new file mode 100644 index 0000000..aa3a2bd --- /dev/null +++ b/tests/test_dice.py @@ -0,0 +1,159 @@ +import pytest + +from ..squigglepy.rng import set_seed +from ..squigglepy.dice import die, coin, Die, Coin +from ..squigglepy.distributions import OperableDistribution +from ..squigglepy.samplers import sample + + +def test_die_basic(): + """Test basic die roll.""" + set_seed(42) + result = ~die(6) + assert result == 1 + + +def test_die_multiple_samples(): + """Test rolling multiple dice.""" + set_seed(42) + result = die(6) @ 10 + assert len(result) == 10 + for r in result: + assert 1 <= r <= 6 + + +def test_die_different_sides(): + """Test dice with different numbers of sides.""" + set_seed(42) + for sides in [4, 8, 10, 12, 20, 100]: + result = die(sides) @ 100 + assert min(result) >= 1 + assert max(result) <= sides + + +def test_die_exploding(): + """Test exploding dice mechanics.""" + set_seed(42) + # With exploding dice on 6, we can get values > 6 + result = die(6, explode_on=6) @ 1000 + # At least some results should be > 6 (from explosions) + assert max(result) > 6 + # All results should be >= 1 + assert min(result) >= 1 + + +def test_die_exploding_multiple_values(): + """Test exploding on multiple values.""" + set_seed(42) + result = die(6, explode_on=[5, 6]) @ 1000 + # Should have higher chance of explosions + assert max(result) > 6 + + +def test_die_str(): + """Test die string representation.""" + d = die(6) + assert " Die(6)" == str(d) + + d_exploding = die(6, explode_on=6) + assert " Die(6, explodes on [6])" == str(d_exploding) + + d_multi_explode = die(6, explode_on=[5, 6]) + assert " Die(6, explodes on [5, 6])" == str(d_multi_explode) + + +def test_die_invalid_sides(): + """Test that invalid sides raise errors.""" + with pytest.raises(ValueError) as excinfo: + die(1) + assert "cannot roll less than a 2-sided die" in str(excinfo.value) + + +def test_die_invalid_sides_type(): + """Test that non-integer sides raise errors.""" + with pytest.raises(ValueError) as excinfo: + die(2.5) + assert "can only roll an integer number of sides" in str(excinfo.value) + + +def test_die_explode_on_all_values(): + """Test that exploding on all values raises error.""" + with pytest.raises(ValueError) as excinfo: + die(6, explode_on=[1, 2, 3, 4, 5, 6]) + assert "cannot explode on every value" in str(excinfo.value) + + +def test_die_explode_invalid_value(): + """Test that invalid explode_on values raise errors.""" + with pytest.raises(ValueError) as excinfo: + die(6, explode_on=7) + assert "explode_on values must be integers between 1 and 6" in str(excinfo.value) + + +def test_coin_basic(): + """Test basic coin flip.""" + set_seed(42) + result = ~coin() + assert result == "tails" + + +def test_coin_multiple_samples(): + """Test flipping multiple coins.""" + set_seed(42) + result = coin() @ 10 + assert len(result) == 10 + for r in result: + assert r in ["heads", "tails"] + + +def test_coin_distribution(): + """Test that coin flips are approximately 50/50.""" + set_seed(42) + result = coin() @ 10000 + heads_count = sum(1 for r in result if r == "heads") + # Should be close to 50% + assert 4500 <= heads_count <= 5500 + + +def test_coin_str(): + """Test coin string representation.""" + c = coin() + assert " Coin" == str(c) + + +def test_die_is_operable(): + """Test that die can be used in operations.""" + set_seed(42) + d = die(6) + # Test that it's an OperableDistribution + assert isinstance(d, OperableDistribution) + assert isinstance(d, Die) + # Test basic operations + double_die = d * 2 + result = ~double_die + assert result == 2 + + +def test_coin_is_operable(): + """Test that coin is an OperableDistribution.""" + c = coin() + assert isinstance(c, OperableDistribution) + assert isinstance(c, Coin) + + +def test_die_sample_function(): + """Test using die with sample function.""" + set_seed(42) + result = sample(die(6), n=5) + assert len(result) == 5 + for r in result: + assert 1 <= r <= 6 + + +def test_coin_sample_function(): + """Test using coin with sample function.""" + set_seed(42) + result = sample(coin(), n=5) + assert len(result) == 5 + for r in result: + assert r in ["heads", "tails"] diff --git a/tests/test_utils.py b/tests/test_utils.py index 1ffdb99..c0b6143 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -21,8 +21,6 @@ laplace, growth_rate_to_doubling_time, doubling_time_to_growth_rate, - roll_die, - flip_coin, kelly, full_kelly, half_kelly, @@ -35,7 +33,7 @@ bucket_percentages, ) from ..squigglepy.rng import set_seed -from ..squigglepy.distributions import bernoulli, beta, norm, dist_round, const +from ..squigglepy.distributions import bernoulli, beta, const def test_process_weights_values_simple_case(): @@ -646,54 +644,6 @@ def test_doubling_time_to_growth_rate_dist(): assert round(doubling_time_to_growth_rate(const(12)) @ 1, 2) == 0.06 -def test_roll_die(): - set_seed(42) - assert roll_die(6) == 5 - - -def test_roll_die_different_sides(): - set_seed(42) - assert roll_die(4) == 4 - - -def test_roll_die_with_distribution(): - set_seed(42) - assert (norm(2, 6) >> dist_round >> roll_die) == 2 - - -def test_roll_one_sided_die(): - with pytest.raises(ValueError) as excinfo: - roll_die(1) - assert "cannot roll less than a 2-sided die" in str(excinfo.value) - - -def test_roll_nonint_die(): - with pytest.raises(ValueError) as excinfo: - roll_die(2.5) - assert "can only roll an integer number of sides" in str(excinfo.value) - - -def test_roll_nonint_n(): - with pytest.raises(ValueError) as excinfo: - roll_die(6, 2.5) - assert "can only roll an integer number of times" in str(excinfo.value) - - -def test_roll_five_die(): - set_seed(42) - assert list(roll_die(4, 4)) == [4, 2, 4, 3] - - -def test_flip_coin(): - set_seed(42) - assert flip_coin() == "heads" - - -def test_flip_five_coins(): - set_seed(42) - assert flip_coin(5) == ["heads", "tails", "heads", "heads", "tails"] - - def test_kelly_market_price_error(): for val in [0, 1, 2, -1]: with pytest.raises(ValueError) as execinfo: