Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* Uses prettier `tqdm` output that is now aware of Jupyter notebooks.
* `bayes.update` now supports `lognorm` and `gamma` distributions in addition to `norm` and `beta`.
* Allow `zero_inflated` to accept 0 or 1 as probabilities.
* **[New feature]** Added `die(sides, explode_on=None)` and `coin()` as distribution objects for dice and coins. These support all distribution operations like `~`, `@`, `+`, `-`, etc. The `die` distribution also supports "exploding dice" mechanics via the `explode_on` parameter.
* **[Breaking change]** Removed `roll_die` and `flip_coin` functions. Use `sq.die(sides) @ n` or `~sq.coin()` instead.

## v0.29 - latest release

Expand Down
34 changes: 24 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,21 +176,34 @@ a, b = sq.correlate((a, b), 0.5) # Correlate a and b with a correlation of 0.5
a, b = sq.correlate((a, b), [[1, 0.5], [0.5, 1]])
```

#### Example: Rolling a Die
#### Example: Rolling Dice and Flipping Coins

An example of how to use distributions to build tools:
Squigglepy has built-in support for dice and coins as distribution objects:

```Python
import squigglepy as sq

def roll_die(sides, n=1):
return sq.discrete(list(range(1, sides + 1))) @ n if sides > 0 else None

roll_die(sides=6, n=10)
# Roll a 6-sided die
sq.die(6) @ 10 # Roll 10 times
# [2, 6, 5, 2, 6, 2, 3, 1, 5, 2]

# Use the ~ operator for a single sample
~sq.die(6) # Roll once
# 4

# Flip a coin
~sq.coin()
# 'heads'

sq.coin() @ 5 # Flip 5 times
# ['heads', 'tails', 'heads', 'heads', 'tails']

# Exploding dice (roll again when you get certain values)
sq.die(6, explode_on=6) @ 5 # D6 that explodes on 6
# [3, 8, 2, 5, 1] # The 8 came from rolling 6 + 2
```

This is already included standard in the utils of this package. Use `sq.roll_die`.
Since `die` and `coin` are distribution objects, they work with all distribution operations:

### Bayesian inference

Expand Down Expand Up @@ -405,10 +418,11 @@ from squigglepy.numbers import K, M, B, T
from squigglepy import bayes

def define_event():
if sq.flip_coin() == 'heads': # Blue bag
return sq.roll_die(6)
if ~sq.coin() == 'heads': # Blue bag
return ~sq.die(6)
else: # Red bag
return sq.discrete([4, 6, 10, 20]) >> sq.roll_die
sides = ~sq.discrete([4, 6, 10, 20])
return ~sq.die(sides)


bayes.bayesnet(define_event,
Expand Down
1 change: 1 addition & 0 deletions squigglepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from .utils import * # noqa ignore=F405
from .rng import * # noqa ignore=F405
from .correlation import * # noqa ignore=F405
from .dice import * # noqa ignore=F405
from .version import __version__ # noqa ignore=F405
105 changes: 105 additions & 0 deletions squigglepy/dice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from .utils import is_dist
from .distributions import OperableDistribution


class Die(OperableDistribution):
"""
A distribution representing a die roll.

Supports exploding dice mechanics where additional dice are rolled
when certain values are rolled.
"""

def __init__(self, sides=None, explode_on=None):
super().__init__()

if is_dist(sides) or callable(sides):
from .samplers import sample

sides = sample(sides)
if sides is None:
raise ValueError("sides must be specified")
if not isinstance(sides, int):
raise ValueError("can only roll an integer number of sides")
if sides < 2:
raise ValueError("cannot roll less than a 2-sided die.")
if explode_on is not None:
if not isinstance(explode_on, list):
explode_on = [explode_on]
if len(explode_on) >= sides:
raise ValueError("cannot explode on every value")
for val in explode_on:
if not isinstance(val, int) or val < 1 or val > sides:
raise ValueError(f"explode_on values must be integers between 1 and {sides}")
self.sides = sides
self.explode_on = explode_on

def __str__(self):
explode_out = (
"" if self.explode_on is None else ", explodes on {}".format(str(self.explode_on))
)
out = "<Distribution> Die({}{})".format(self.sides, explode_out)
return out


def die(sides, explode_on=None):
"""
Create a distribution for a die.

Parameters
----------
sides : int
The number of sides of the die that is rolled.
explode_on : list or int or None
An additional die will be rolled if the initial die rolls any of these values.
Implements "exploding dice" mechanics. The exploding continues recursively
until a non-exploding value is rolled.

Returns
-------
Die
A distribution that models a die roll, returning values from 1 to sides.

Examples
--------
>>> die(6)
<Distribution> Die(6)
>>> die(6, explode_on=6) # D6 that explodes on 6
<Distribution> Die(6, explodes on [6])
"""
return Die(sides=sides, explode_on=explode_on)


class Coin(OperableDistribution):
"""
A distribution representing a coin flip.

Returns either "heads" or "tails" with equal probability.
"""

def __init__(self):
super().__init__()

def __str__(self):
out = "<Distribution> Coin"
return out


def coin():
"""
Create a distribution for a coin flip.

Returns
-------
Coin
A distribution that models a coin flip, returning either "heads" or "tails"
with equal probability.

Examples
--------
>>> coin()
<Distribution> Coin
>>> ~coin() # Sample a coin flip
'heads' # or 'tails'
"""
return Coin()
79 changes: 79 additions & 0 deletions squigglepy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
const,
)

from .dice import Coin, Die

_squigglepy_internal_sample_caches = {}


Expand Down Expand Up @@ -643,6 +645,77 @@ def geometric_sample(p, samples=1):
return _simplify(_get_rng().geometric(p, samples))


def die_sample(sides, explode_on=None, samples=1):
"""
Sample a random number from a die roll.

Parameters
----------
sides : int
The number of sides of the die.
explode_on : list or int or None
Values that trigger an additional die roll. Implements "exploding dice" mechanics.
samples : int
The number of samples to return.

Returns
-------
int or array
A random number (or array of numbers) from rolling the die.

Examples
--------
>>> set_seed(42)
>>> die_sample(6)
5
>>> die_sample(6, explode_on=6, samples=3)
array([4, 2, 8]) # The 8 came from rolling a 6 and then a 2
"""

def _single_roll():
total = 0
roll = int(_get_rng().integers(1, sides + 1))
total += roll
if explode_on is not None:
explode_list = explode_on if isinstance(explode_on, list) else [explode_on]
while roll in explode_list:
roll = int(_get_rng().integers(1, sides + 1))
total += roll
return total

if samples == 1:
return _single_roll()
else:
return np.array([_single_roll() for _ in range(samples)])


def coin_sample(samples=1):
"""
Sample a coin flip.

Parameters
----------
samples : int
The number of samples to return.

Returns
-------
str or list
"heads" or "tails" for a single sample, or a list of results for multiple samples.

Examples
--------
>>> set_seed(42)
>>> coin_sample()
'heads'
"""
if samples == 1:
return "heads" if _get_rng().integers(0, 2) == 1 else "tails"
else:
results = _get_rng().integers(0, 2, samples)
return ["heads" if r == 1 else "tails" for r in results]


def _mixture_sample_for_large_n(
values,
weights=None,
Expand Down Expand Up @@ -1133,6 +1206,12 @@ def run_dist(dist, pbar=None, tick=1):
elif isinstance(dist, GeometricDistribution):
samples = geometric_sample(p=dist.p, samples=n)

elif isinstance(dist, Die):
samples = die_sample(sides=dist.sides, explode_on=dist.explode_on, samples=n)

elif isinstance(dist, Coin):
samples = coin_sample(samples=n)

elif isinstance(dist, ComplexDistribution):
if dist.right is None:
samples = dist.fn(sample(dist.left, n=n, verbose=verbose))
Expand Down
66 changes: 0 additions & 66 deletions squigglepy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,72 +888,6 @@ def doubling_time_to_growth_rate(doubling_time):
return math.exp(math.log(2) / doubling_time) - 1


def roll_die(sides, n=1):
"""
Roll a die.

Parameters
----------
sides : int
The number of sides of the die that is rolled.
n : int
The number of dice to be rolled.

Returns
-------
int or list
Returns the value of each die roll.

Examples
--------
>>> set_seed(42)
>>> roll_die(6)
5
"""
if is_dist(sides) or callable(sides):
from .samplers import sample

sides = sample(sides)
if not isinstance(n, int):
raise ValueError("can only roll an integer number of times")
elif sides < 2:
raise ValueError("cannot roll less than a 2-sided die.")
elif not isinstance(sides, int):
raise ValueError("can only roll an integer number of sides")
else:
from .samplers import sample
from .distributions import discrete

return sample(discrete(list(range(1, sides + 1))), n=n) if sides > 0 else None


def flip_coin(n=1):
"""
Flip a coin.

Parameters
----------
n : int
The number of coins to be flipped.

Returns
-------
str or list
Returns the value of each coin flip, as either "heads" or "tails"

Examples
--------
>>> set_seed(42)
>>> flip_coin()
'heads'
"""
rolls = roll_die(2, n=n)
if isinstance(rolls, int):
rolls = [rolls]
flips = ["heads" if d == 2 else "tails" for d in rolls]
return flips[0] if len(flips) == 1 else flips


def kelly(
my_price, market_price, deference=0, bankroll=1, resolve_date=None, current=0, error=True
):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,12 @@ def monte_hall_event():


def coins_and_dice():
flip = sq.flip_coin()
flip = ~sq.coin()
if flip == "heads":
dice_sides = 6
else:
dice_sides = ~sq.discrete([4, 6, 10, 20])
return sq.roll_die(dice_sides)
return ~sq.die(dice_sides)


def model():
Expand Down
Loading