Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pooltool/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
np.set_printoptions(precision=16, suppress=True)

EPS = np.finfo(float).eps * 100
EPS_SPACE = 1e-9

MIN_DIST = 1e-6
"""The minimum distance between balls."""

# Ball states
stationary: int = 0
"""The stationary motion state label

A ball with this motion state is both motionless and not in a pocket.
"""

spinning: int = 1
"""The spinning motion state label

Expand Down
15 changes: 2 additions & 13 deletions pooltool/evolution/event_based/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_collision_events_from_cache(
class SimulationSnapshot:
step_number: int
system: System
selected_event: Event
next_event: Event
collision_cache: CollisionCache
transition_cache: TransitionCache
engine: PhysicsEngine
Expand Down Expand Up @@ -227,7 +227,7 @@ def simulate_with_snapshots(
snapshot = SimulationSnapshot(
step_number=step,
system=system_pre_evolve,
selected_event=event,
next_event=event,
collision_cache=collision_cache_snapshot,
transition_cache=transition_cache_snapshot,
engine=engine,
Expand All @@ -240,14 +240,3 @@ def simulate_with_snapshots(
step += 1

return sim.shot, snapshot_sequence


if __name__ == "__main__":
from pathlib import Path

import pooltool as pt

output = Path("test.json")
simulate_with_snapshots(pt.System.example(), output)
seq = SimulationSnapshotSequence.load(output)
system = pt.simulate(pt.System.example())
129 changes: 94 additions & 35 deletions pooltool/evolution/event_based/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,61 @@ def _system_has_energy(system: System) -> bool:
)


def get_event_priority(event: Event, shot: System) -> tuple[int, float]:
"""Compute priority for an event to resolve ties among simultaneous events.

Returns a tuple (tier, energy) where:
- Lower tier = higher priority
- Higher energy = higher priority within the same tier

Priority tiers:
- Tier 1: STICK_BALL (always first)
- Tier 2: Transitions and BALL_POCKET (can resolve without affecting others)
- Tier 3: BALL_BALL and ball-cushion collisions

Args:
event: The event to compute priority for.
shot: The system state at the time the event was detected.

Returns:
A tuple of (tier, energy) for sorting.
"""
event_type = event.event_type

if event_type == EventType.NONE:
return (99, 0.0)

if event_type == EventType.STICK_BALL:
return (1, shot.cue.V0**2)

if event_type == EventType.BALL_POCKET:
ball_id = event.ids[0]
ball = shot.balls[ball_id]
energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m)
return (2, energy)

if event_type.is_transition():
ball_id = event.ids[0]
ball = shot.balls[ball_id]
energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m)
return (2, energy)

if event_type == EventType.BALL_BALL:
ball1_id, ball2_id = event.ids
v1 = shot.balls[ball1_id].state.rvw[1]
v2 = shot.balls[ball2_id].state.rvw[1]
energy = ptmath.squared_norm3d(v1 - v2)
return (3, energy)

if event_type in (EventType.BALL_LINEAR_CUSHION, EventType.BALL_CIRCULAR_CUSHION):
ball_id = event.ids[0]
ball = shot.balls[ball_id]
energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m)
return (3, energy)

return (99, 0.0)


@attrs.define
class _SimulationState:
shot: System
Expand Down Expand Up @@ -97,6 +152,7 @@ def step(self) -> Event:

if self.max_events > 0 and self.num_events > self.max_events:
self.shot.stop_balls()
self.shot._update_history(null_event(time=self.shot.t))
self.done = True

self.num_events += 1
Expand Down Expand Up @@ -252,50 +308,52 @@ def get_next_event(
if collision_cache is None:
collision_cache = CollisionCache.create()

# Start by assuming next event doesn't happen
event = null_event(time=np.inf)
# Collect all candidate events from each detection function.
candidates: list[Event] = []

# Stick-ball collisions only occur at t=0 (shot initiation), so we skip this
# check after the first timestep as an optimization. Other collision types are
# always checked because they can occur at any time during simulation. Note: even
# at t=0, we still call the remaining detection functions to fully populate the
# collision cache, which is needed by debug/introspection tools.
if shot.t == 0:
stick_ball_event = get_next_stick_ball_collision(
shot, collision_cache=collision_cache
candidates.append(
get_next_stick_ball_collision(shot, collision_cache=collision_cache)
)
if stick_ball_event.time < event.time:
event = stick_ball_event

transition_event = transition_cache.get_next()
if transition_event.time < event.time:
event = transition_event

ball_ball_event = get_next_ball_ball_collision(
shot, collision_cache=collision_cache
candidates.append(transition_cache.get_next())
candidates.append(
get_next_ball_ball_collision(shot, collision_cache=collision_cache)
)
if ball_ball_event.time < event.time:
event = ball_ball_event

ball_circular_cushion_event = get_next_ball_circular_cushion_event(
shot, collision_cache=collision_cache
candidates.append(
get_next_ball_circular_cushion_event(shot, collision_cache=collision_cache)
)
if ball_circular_cushion_event.time < event.time:
event = ball_circular_cushion_event

ball_linear_cushion_event = get_next_ball_linear_cushion_collision(
shot, collision_cache=collision_cache
candidates.append(
get_next_ball_linear_cushion_collision(shot, collision_cache=collision_cache)
)
if ball_linear_cushion_event.time < event.time:
event = ball_linear_cushion_event

ball_pocket_event = get_next_ball_pocket_collision(
shot, collision_cache=collision_cache
candidates.append(
get_next_ball_pocket_collision(shot, collision_cache=collision_cache)
)
if ball_pocket_event.time < event.time:
event = ball_pocket_event

return event
# Find the earliest time among all candidates.
min_time = min(event.time for event in candidates)

if min_time == np.inf:
return null_event(time=np.inf)

# Filter to only events occurring at the earliest time.
simultaneous = [e for e in candidates if e.time == min_time]

if len(simultaneous) == 1:
return simultaneous[0]

# When multiple events occur at the same time, select by priority tier, then by
# energy within the tier (higher energy first).
def sort_key(e: Event) -> tuple[int, float]:
tier, energy = get_event_priority(e, shot)
return (tier, -energy)

return min(simultaneous, key=sort_key)


def get_next_stick_ball_collision(
Expand Down Expand Up @@ -352,12 +410,13 @@ def get_next_ball_ball_collision(
and ball2_state.s in const.nontranslating
):
cache[ball_pair] = np.inf
elif (
ptmath.norm3d(ball1_state.rvw[0] - ball2_state.rvw[0])
< ball1_params.R + ball2_params.R
elif ptmath.is_overlapping(
ball1_state.rvw,
ball2_state.rvw,
ball1_params.R,
ball2_params.R,
):
# If balls are intersecting, avoid internal collisions
cache[ball_pair] = np.inf
cache[ball_pair] = shot.t
else:
dtau_E = solve.ball_ball_collision_time(
rvw1=ball1_state.rvw,
Expand Down
28 changes: 15 additions & 13 deletions pooltool/evolution/event_based/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,11 @@ def ball_ball_collision_coeffs(
Bx, By = b2x - b1x, b2y - b1y
Cx, Cy = c2x - c1x, c2y - c1y

a = Ax**2 + Ay**2
a = Ax * Ax + Ay * Ay
b = 2 * Ax * Bx + 2 * Ay * By
c = Bx**2 + 2 * Ax * Cx + 2 * Ay * Cy + By**2
c = Bx * Bx + 2 * Ax * Cx + 2 * Ay * Cy + By * By
d = 2 * Bx * Cx + 2 * By * Cy
e = Cx**2 + Cy**2 - 4 * R**2
e = Cx * Cx + Cy * Cy - 4 * R * R

return a, b, c, d, e

Expand Down Expand Up @@ -237,16 +237,16 @@ def ball_linear_cushion_collision_time(
B = lx * bx + ly * by

if direction == 0:
C = l0 + lx * cx + ly * cy + R * np.sqrt(lx**2 + ly**2)
C = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly)
root1, root2 = ptmath.roots.quadratic.solve(A, B, C)
roots = [root1, root2]
elif direction == 1:
C = l0 + lx * cx + ly * cy - R * np.sqrt(lx**2 + ly**2)
C = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly)
root1, root2 = ptmath.roots.quadratic.solve(A, B, C)
roots = [root1, root2]
else:
C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx**2 + ly**2)
C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx**2 + ly**2)
C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly)
C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly)
root1, root2 = ptmath.roots.quadratic.solve(A, B, C1)
root3, root4 = ptmath.roots.quadratic.solve(A, B, C2)
roots = [root1, root2, root3, root4]
Expand Down Expand Up @@ -308,11 +308,13 @@ def ball_circular_cushion_collision_coeffs(
bx, by = v * cos_phi, v * sin_phi
cx, cy = rvw[0, 0], rvw[0, 1]

A = 0.5 * (ax**2 + ay**2)
A = 0.5 * (ax * ax + ay * ay)
B = ax * bx + ay * by
C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx**2 + by**2)
C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by)
D = bx * (cx - a) + by * (cy - b)
E = 0.5 * (a**2 + b**2 + cx**2 + cy**2 - (r + R) ** 2) - (cx * a + cy * b)
E = 0.5 * (a * a + b * b + cx * cx + cy * cy - (r + R) * (r + R)) - (
cx * a + cy * b
)

return A, B, C, D, E

Expand Down Expand Up @@ -381,11 +383,11 @@ def ball_pocket_collision_coeffs(
bx, by = v * cos_phi, v * sin_phi
cx, cy = rvw[0, 0], rvw[0, 1]

A = 0.5 * (ax**2 + ay**2)
A = 0.5 * (ax * ax + ay * ay)
B = ax * bx + ay * by
C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx**2 + by**2)
C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by)
D = bx * (cx - a) + by * (cy - b)
E = 0.5 * (a**2 + b**2 + cx**2 + cy**2 - r**2) - (cx * a + cy * b)
E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b)

return A, B, C, D, E

Expand Down
Loading