diff --git a/pooltool/constants.py b/pooltool/constants.py index f180a69e..bbc3d852 100644 --- a/pooltool/constants.py +++ b/pooltool/constants.py @@ -12,7 +12,9 @@ np.set_printoptions(precision=16, suppress=True) EPS = np.finfo(float).eps * 100 -EPS_SPACE = 1e-9 + +MIN_DIST = 1e-6 +"""The minimum distance between balls.""" # Ball states stationary: int = 0 @@ -20,6 +22,7 @@ A ball with this motion state is both motionless and not in a pocket. """ + spinning: int = 1 """The spinning motion state label diff --git a/pooltool/evolution/event_based/introspection.py b/pooltool/evolution/event_based/introspection.py index 74809322..16d0d204 100644 --- a/pooltool/evolution/event_based/introspection.py +++ b/pooltool/evolution/event_based/introspection.py @@ -111,7 +111,7 @@ def _get_collision_events_from_cache( class SimulationSnapshot: step_number: int system: System - selected_event: Event + next_event: Event collision_cache: CollisionCache transition_cache: TransitionCache engine: PhysicsEngine @@ -227,7 +227,7 @@ def simulate_with_snapshots( snapshot = SimulationSnapshot( step_number=step, system=system_pre_evolve, - selected_event=event, + next_event=event, collision_cache=collision_cache_snapshot, transition_cache=transition_cache_snapshot, engine=engine, @@ -240,14 +240,3 @@ def simulate_with_snapshots( step += 1 return sim.shot, snapshot_sequence - - -if __name__ == "__main__": - from pathlib import Path - - import pooltool as pt - - output = Path("test.json") - simulate_with_snapshots(pt.System.example(), output) - seq = SimulationSnapshotSequence.load(output) - system = pt.simulate(pt.System.example()) diff --git a/pooltool/evolution/event_based/simulate.py b/pooltool/evolution/event_based/simulate.py index f632ee79..a20b9e63 100755 --- a/pooltool/evolution/event_based/simulate.py +++ b/pooltool/evolution/event_based/simulate.py @@ -50,6 +50,61 @@ def _system_has_energy(system: System) -> bool: ) +def get_event_priority(event: Event, shot: System) -> tuple[int, float]: + """Compute priority for an event to resolve ties among simultaneous events. + + Returns a tuple (tier, energy) where: + - Lower tier = higher priority + - Higher energy = higher priority within the same tier + + Priority tiers: + - Tier 1: STICK_BALL (always first) + - Tier 2: Transitions and BALL_POCKET (can resolve without affecting others) + - Tier 3: BALL_BALL and ball-cushion collisions + + Args: + event: The event to compute priority for. + shot: The system state at the time the event was detected. + + Returns: + A tuple of (tier, energy) for sorting. + """ + event_type = event.event_type + + if event_type == EventType.NONE: + return (99, 0.0) + + if event_type == EventType.STICK_BALL: + return (1, shot.cue.V0**2) + + if event_type == EventType.BALL_POCKET: + ball_id = event.ids[0] + ball = shot.balls[ball_id] + energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + return (2, energy) + + if event_type.is_transition(): + ball_id = event.ids[0] + ball = shot.balls[ball_id] + energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + return (2, energy) + + if event_type == EventType.BALL_BALL: + ball1_id, ball2_id = event.ids + v1 = shot.balls[ball1_id].state.rvw[1] + v2 = shot.balls[ball2_id].state.rvw[1] + energy = ptmath.squared_norm3d(v1 - v2) + return (3, energy) + + if event_type in (EventType.BALL_LINEAR_CUSHION, EventType.BALL_CIRCULAR_CUSHION): + ball_id = event.ids[0] + ball = shot.balls[ball_id] + energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + return (3, energy) + + return (99, 0.0) + + @attrs.define class _SimulationState: shot: System @@ -97,6 +152,7 @@ def step(self) -> Event: if self.max_events > 0 and self.num_events > self.max_events: self.shot.stop_balls() + self.shot._update_history(null_event(time=self.shot.t)) self.done = True self.num_events += 1 @@ -252,8 +308,8 @@ def get_next_event( if collision_cache is None: collision_cache = CollisionCache.create() - # Start by assuming next event doesn't happen - event = null_event(time=np.inf) + # Collect all candidate events from each detection function. + candidates: list[Event] = [] # Stick-ball collisions only occur at t=0 (shot initiation), so we skip this # check after the first timestep as an optimization. Other collision types are @@ -261,41 +317,43 @@ def get_next_event( # at t=0, we still call the remaining detection functions to fully populate the # collision cache, which is needed by debug/introspection tools. if shot.t == 0: - stick_ball_event = get_next_stick_ball_collision( - shot, collision_cache=collision_cache + candidates.append( + get_next_stick_ball_collision(shot, collision_cache=collision_cache) ) - if stick_ball_event.time < event.time: - event = stick_ball_event - - transition_event = transition_cache.get_next() - if transition_event.time < event.time: - event = transition_event - ball_ball_event = get_next_ball_ball_collision( - shot, collision_cache=collision_cache + candidates.append(transition_cache.get_next()) + candidates.append( + get_next_ball_ball_collision(shot, collision_cache=collision_cache) ) - if ball_ball_event.time < event.time: - event = ball_ball_event - - ball_circular_cushion_event = get_next_ball_circular_cushion_event( - shot, collision_cache=collision_cache + candidates.append( + get_next_ball_circular_cushion_event(shot, collision_cache=collision_cache) ) - if ball_circular_cushion_event.time < event.time: - event = ball_circular_cushion_event - - ball_linear_cushion_event = get_next_ball_linear_cushion_collision( - shot, collision_cache=collision_cache + candidates.append( + get_next_ball_linear_cushion_collision(shot, collision_cache=collision_cache) ) - if ball_linear_cushion_event.time < event.time: - event = ball_linear_cushion_event - - ball_pocket_event = get_next_ball_pocket_collision( - shot, collision_cache=collision_cache + candidates.append( + get_next_ball_pocket_collision(shot, collision_cache=collision_cache) ) - if ball_pocket_event.time < event.time: - event = ball_pocket_event - return event + # Find the earliest time among all candidates. + min_time = min(event.time for event in candidates) + + if min_time == np.inf: + return null_event(time=np.inf) + + # Filter to only events occurring at the earliest time. + simultaneous = [e for e in candidates if e.time == min_time] + + if len(simultaneous) == 1: + return simultaneous[0] + + # When multiple events occur at the same time, select by priority tier, then by + # energy within the tier (higher energy first). + def sort_key(e: Event) -> tuple[int, float]: + tier, energy = get_event_priority(e, shot) + return (tier, -energy) + + return min(simultaneous, key=sort_key) def get_next_stick_ball_collision( @@ -352,12 +410,13 @@ def get_next_ball_ball_collision( and ball2_state.s in const.nontranslating ): cache[ball_pair] = np.inf - elif ( - ptmath.norm3d(ball1_state.rvw[0] - ball2_state.rvw[0]) - < ball1_params.R + ball2_params.R + elif ptmath.is_overlapping( + ball1_state.rvw, + ball2_state.rvw, + ball1_params.R, + ball2_params.R, ): - # If balls are intersecting, avoid internal collisions - cache[ball_pair] = np.inf + cache[ball_pair] = shot.t else: dtau_E = solve.ball_ball_collision_time( rvw1=ball1_state.rvw, diff --git a/pooltool/evolution/event_based/solve.py b/pooltool/evolution/event_based/solve.py index 5ac90fea..c1a80ed1 100644 --- a/pooltool/evolution/event_based/solve.py +++ b/pooltool/evolution/event_based/solve.py @@ -154,11 +154,11 @@ def ball_ball_collision_coeffs( Bx, By = b2x - b1x, b2y - b1y Cx, Cy = c2x - c1x, c2y - c1y - a = Ax**2 + Ay**2 + a = Ax * Ax + Ay * Ay b = 2 * Ax * Bx + 2 * Ay * By - c = Bx**2 + 2 * Ax * Cx + 2 * Ay * Cy + By**2 + c = Bx * Bx + 2 * Ax * Cx + 2 * Ay * Cy + By * By d = 2 * Bx * Cx + 2 * By * Cy - e = Cx**2 + Cy**2 - 4 * R**2 + e = Cx * Cx + Cy * Cy - 4 * R * R return a, b, c, d, e @@ -237,16 +237,16 @@ def ball_linear_cushion_collision_time( B = lx * bx + ly * by if direction == 0: - C = l0 + lx * cx + ly * cy + R * np.sqrt(lx**2 + ly**2) + C = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) root1, root2 = ptmath.roots.quadratic.solve(A, B, C) roots = [root1, root2] elif direction == 1: - C = l0 + lx * cx + ly * cy - R * np.sqrt(lx**2 + ly**2) + C = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) root1, root2 = ptmath.roots.quadratic.solve(A, B, C) roots = [root1, root2] else: - C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx**2 + ly**2) - C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx**2 + ly**2) + C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) + C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) root1, root2 = ptmath.roots.quadratic.solve(A, B, C1) root3, root4 = ptmath.roots.quadratic.solve(A, B, C2) roots = [root1, root2, root3, root4] @@ -308,11 +308,13 @@ def ball_circular_cushion_collision_coeffs( bx, by = v * cos_phi, v * sin_phi cx, cy = rvw[0, 0], rvw[0, 1] - A = 0.5 * (ax**2 + ay**2) + A = 0.5 * (ax * ax + ay * ay) B = ax * bx + ay * by - C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx**2 + by**2) + C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a**2 + b**2 + cx**2 + cy**2 - (r + R) ** 2) - (cx * a + cy * b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - (r + R) * (r + R)) - ( + cx * a + cy * b + ) return A, B, C, D, E @@ -381,11 +383,11 @@ def ball_pocket_collision_coeffs( bx, by = v * cos_phi, v * sin_phi cx, cy = rvw[0, 0], rvw[0, 1] - A = 0.5 * (ax**2 + ay**2) + A = 0.5 * (ax * ax + ay * ay) B = ax * bx + ay * by - C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx**2 + by**2) + C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a**2 + b**2 + cx**2 + cy**2 - r**2) - (cx * a + cy * b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) return A, B, C, D, E diff --git a/pooltool/layouts.py b/pooltool/layouts.py index 55466deb..f08a0c70 100755 --- a/pooltool/layouts.py +++ b/pooltool/layouts.py @@ -52,48 +52,59 @@ def translation_map(cls) -> dict[Dir, tuple[float, float]]: } +Translation = Dir | float + + class Jump: @staticmethod - def LEFT(quantity: int = 1) -> list[Dir]: + def LEFT(quantity: int = 1) -> list[Translation]: return [Dir.LEFT] * quantity @staticmethod - def RIGHT(quantity: int = 1) -> list[Dir]: + def RIGHT(quantity: int = 1) -> list[Translation]: return [Dir.RIGHT] * quantity @staticmethod - def UP(quantity: int = 1) -> list[Dir]: + def UP(quantity: int = 1) -> list[Translation]: return [Dir.UP] * quantity @staticmethod - def DOWN(quantity: int = 1) -> list[Dir]: + def DOWN(quantity: int = 1) -> list[Translation]: return [Dir.DOWN] * quantity @staticmethod - def UPLEFT(quantity: int = 1) -> list[Dir]: + def UPLEFT(quantity: int = 1) -> list[Translation]: return [Dir.UPLEFT] * quantity @staticmethod - def UPRIGHT(quantity: int = 1) -> list[Dir]: + def UPRIGHT(quantity: int = 1) -> list[Translation]: return [Dir.UPRIGHT] * quantity @staticmethod - def DOWNRIGHT(quantity: int = 1) -> list[Dir]: + def DOWNRIGHT(quantity: int = 1) -> list[Translation]: return [Dir.DOWNRIGHT] * quantity @staticmethod - def DOWNLEFT(quantity: int = 1) -> list[Dir]: + def DOWNLEFT(quantity: int = 1) -> list[Translation]: return [Dir.DOWNLEFT] * quantity @staticmethod - def eval(translations: list[Dir], radius: float) -> tuple[float, float]: + def ANGLE(degrees: float, quantity: int = 1) -> list[Translation]: + radians = np.radians(degrees) + return [radians] * quantity + + @staticmethod + def eval(translations: list[Translation], radius: float) -> tuple[float, float]: mapping = Dir.translation_map assert isinstance(mapping, dict) - dx, dy = 0, 0 + dx, dy = 0.0, 0.0 - for direction in translations: - i, j = mapping[direction] + for translation in translations: + if isinstance(translation, Dir): + i, j = mapping[translation] + else: + i, j = 2 * np.cos(translation), 2 * np.sin(translation) dx += i * radius dy += j * radius @@ -106,7 +117,9 @@ class Pos: Attributes: loc: - A sequence of translations. + A sequence of translations. Each translation is either a Dir enum + for discrete directions, or a float representing an angle in radians + (0 = right, pi/2 = up). Use Jump.ANGLE(degrees) for convenience. relative_to: This defines what the translation is with respect to. This can either be another Pos, or a 2D coordinate, normalized by the table's @@ -114,7 +127,7 @@ class Pos: so (0.0, 0.0) is bottom-left and (1.0, 1.0) is top right. """ - loc: list[Dir] + loc: list[Translation] relative_to: Pos | tuple[float, float] @@ -130,7 +143,7 @@ class BallPos(Pos): ids: set[str] -JumpSequence = list[tuple[list[Dir], set[str]]] +JumpSequence = list[tuple[list[Translation], set[str]]] def ball_cluster_blueprint(seed: BallPos, jump_sequence: JumpSequence) -> list[BallPos]: @@ -153,10 +166,10 @@ def _get_ball_ids(positions: list[BallPos]) -> set[str]: return ids -def _get_anchor_translation(pos: Pos) -> tuple[tuple[float, float], list[Dir]]: +def _get_anchor_translation(pos: Pos) -> tuple[tuple[float, float], list[Translation]]: """Traverse the position's parent hierarchy until the anchor is found""" - translation_from_anchor: list[Dir] = [] + translation_from_anchor: list[Translation] = [] translation_from_anchor.extend(pos.loc) parent = pos.relative_to @@ -543,6 +556,7 @@ def get_rack( "Jump", "Pos", "BallPos", + "Translation", "ball_cluster_blueprint", "generate_layout", "get_rack", diff --git a/pooltool/physics/resolve/ball_ball/core.py b/pooltool/physics/resolve/ball_ball/core.py index c0d49ea7..0a90ebad 100644 --- a/pooltool/physics/resolve/ball_ball/core.py +++ b/pooltool/physics/resolve/ball_ball/core.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Protocol +import numpy as np + import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.objects.ball.datatypes import Ball @@ -25,21 +27,201 @@ def solve(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: class CoreBallBallCollision(ABC): """Operations used by every ball-ball collision resolver""" + def _apply_fallback_positioning( + self, + ball1: Ball, + ball2: Ball, + r1: np.ndarray, + r2: np.ndarray, + spacer: float, + ) -> tuple[np.ndarray, np.ndarray]: + """Apply fallback positioning by moving balls along line of centers. + + This fallback strategy moves balls uniformly along the line of centers until + they're separated by the target distance (2*R + spacer). + """ + correction = 2 * ball1.params.R - ptmath.norm3d(r2 - r1) + spacer + r1_corrected = r1 - correction / 2 * ptmath.unit_vector(r2 - r1) + r2_corrected = r2 + correction / 2 * ptmath.unit_vector(r2 - r1) + return r1_corrected, r2_corrected + def make_kiss(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: - """Translate the balls so they are (almost) touching + """Position balls at precise target separation before collision resolution. + + This method adjusts ball positions so they are separated by exactly 2*R + + spacer, where R is the ball radius and ``spacer`` is a small epsilon to prevent + ball intersection that occurs due to floating-point precision if an explicit + spacer is not added. + + The primary method solves a quadratic equation to find the time offset that + positions the balls at the target separation. Balls are moved along their + trajectories (position + velocity * time) to this configuration. Acceleration + terms are assumed negligible. + + If both balls are non-translating, or if the midpoint (collision point) shifts + by more than 5x the spacer (which can occur if balls are moving with nearly the + same velocity), a naive fallback strategy is used that moves the balls + uniformly along the line of centers until they're separated by an amount + ``spacer``. + + Algorithm: + 1. If both balls are non-translating, apply fallback + 2. Otherwise, calculate quadratic coefficients for separation equation + 3. Solve for time offset that achieves target separation + 4. Move balls to corrected positions: r_new = r + t * v + 5. If midpoint shifts more than 5x spacer, apply fallback + + Returns: + tuple[Ball, Ball]: + ``ball1`` and ``ball2`` modified in place with adjusted positions. + """ + r1 = ball1.state.rvw[0] + r2 = ball2.state.rvw[0] + v1 = ball1.state.rvw[1] + v2 = ball2.state.rvw[1] + + spacer = const.MIN_DIST + + if ( + ball1.state.s in const.nontranslating + and ball2.state.s in const.nontranslating + ): + r1_corrected, r2_corrected = self._apply_fallback_positioning( + ball1, ball2, r1, r2, spacer + ) + else: + Bx = v2[0] - v1[0] + By = v2[1] - v1[1] + Bz = v2[2] - v1[2] + Cx = r2[0] - r1[0] + Cy = r2[1] - r1[1] + Cz = r2[2] - r1[2] + alpha = Bx * Bx + By * By + Bz * Bz + beta = 2 * Bx * Cx + 2 * By * Cy + 2 * Bz * Cz + gamma = ( + Cx * Cx + + Cy * Cy + + Cz * Cz + - (2 * ball1.params.R + spacer) * (2 * ball1.params.R + spacer) + ) + roots_complex = ptmath.roots.quadratic.solve_complex(alpha, beta, gamma) + + imag_mag = np.abs(roots_complex.imag) + real_mag = np.abs(roots_complex.real) + keep = (imag_mag / real_mag) < 1e-3 + roots = roots_complex[keep].real + t = roots[np.abs(roots).argmin()] + + r1_corrected = r1 + t * v1 + r2_corrected = r2 + t * v2 + + midpoint = (r1 + r2) / 2 + midpoint_corrected = (r1_corrected + r2_corrected) / 2 + if ptmath.norm3d(midpoint - midpoint_corrected) > 5 * spacer: + r1_corrected, r2_corrected = self._apply_fallback_positioning( + ball1, ball2, r1, r2, spacer + ) + + ball1.state.rvw[0] = r1_corrected + ball2.state.rvw[0] = r2_corrected + + return ball1, ball2 - This makes a correction such that if the balls are not _exactly_ 2*R apart, they - are moved equally along their line of centers such that they are. Then, to avoid - downstream float precision round-off errors, a small epsilon of additional - distance (constants.EPS_SPACE) is put between them, ensuring the balls are - non-intersecting. + def resolve_continually_touching( + self, ball1: Ball, ball2: Ball + ) -> tuple[Ball, Ball]: + """Prevent repeated collision detection for nearly-touching balls moving in unison. + + This method is called to handle rare cases where balls are moving with very + similar velocities. This can happen in some edge cases when frozen balls in a + perfect line are given energy along their line, (e.g. Newton's cradle). Without + intervention, the balls repeatedly trigger events microseconds apart that stall + progression of the simulation, sometimes indefinitely, via an explosion of + events. + + This is an unfortunate consequence of modeling non-instantaneous multibody + collisions using instantaneous pairwise collisions, and resolving it requires + some phenomonelogical intervention that hopefully appears to be realistic, + despite it not being grounded in theory. + + The solution is applied in this method, and uses a momentum transfer mechanism: + the "chased" ball (slower in the line-of-centers direction) steals a fraction of + the "chaser's" radial momentum. This creates gradual separation over time, + preventing the balls from triggering repeated collision events while maintaining + physically plausible behavior. + + Algorithm: + 1. Projects velocities onto line of centers to get radial components + 2. If radial relative velocity is below threshold (< 1mm/s): + - Identifies which ball is "chasing" (higher radial velocity) + - Chased ball steals 10% of chaser's radial momentum + - Chaser loses this momentum, chased gains it + 3. Tangential velocity components remain unchanged + + Args: + ball1: First ball in the collision + ball2: Second ball in the collision + + Returns: + Modified ball1 and ball2 with adjusted velocities + + Notes: + - Practically speaking, this is a no-op method for all but the most + contrived simulation conditions. For one such condition, see + `sandbox/newtons_cradle.py` """ - r1, r2 = ball1.state.rvw[0], ball2.state.rvw[0] - n = ptmath.unit_vector(r2 - r1) + r1 = ball1.state.rvw[0] + r2 = ball2.state.rvw[0] + v1 = ball1.state.rvw[1] + v2 = ball2.state.rvw[1] + + v1_speed = ptmath.norm3d(v1) + v2_speed = ptmath.norm3d(v2) + both_moving = v1_speed > 0 and v2_speed > 0 + + if not both_moving: + return ball1, ball2 + + theft_fraction = 0.10 + velocity_similarity_threshold = 0.9 - correction = 2 * ball1.params.R - ptmath.norm3d(r2 - r1) + const.EPS_SPACE - ball2.state.rvw[0] += correction / 2 * n - ball1.state.rvw[0] -= correction / 2 * n + line_of_centers = ptmath.unit_vector(r2 - r1) + + # Velocities projected onto the line of centers (loc). + v1_loc = np.dot(v1, line_of_centers) + v2_loc = np.dot(v2, line_of_centers) + + cosine_similarity = np.dot(v1, v2) / (v1_speed * v2_speed) + velocities_aligned = cosine_similarity > velocity_similarity_threshold + + if abs(v2_loc - v1_loc) < 0.01 and velocities_aligned: + if v1_loc > v2_loc: + chaser_loc_vel = v1_loc + ball1_is_chaser = True + else: + chaser_loc_vel = v2_loc + ball1_is_chaser = False + + # Chased ball steals fraction of chaser's line of centers momentum + # FIXME: We assume equal mass, so transfer velocity directly + stolen_loc_velocity = chaser_loc_vel * theft_fraction + + if ball1_is_chaser: + v1_loc_new = v1_loc - stolen_loc_velocity + v2_loc_new = v2_loc + stolen_loc_velocity + else: + v1_loc_new = v1_loc + stolen_loc_velocity + v2_loc_new = v2_loc - stolen_loc_velocity + + v1_corrected = v1 - v1_loc * line_of_centers + v1_loc_new * line_of_centers + v2_corrected = v2 - v2_loc * line_of_centers + v2_loc_new * line_of_centers + + momentum_before = v1 + v2 + momentum_after = v1_corrected + v2_corrected + assert np.allclose(momentum_before, momentum_after, rtol=1e-10) + + ball1.state.rvw[1] = v1_corrected + ball2.state.rvw[1] = v2_corrected return ball1, ball2 @@ -51,8 +233,10 @@ def resolve( ball2 = ball2.copy() ball1, ball2 = self.make_kiss(ball1, ball2) + ball1, ball2 = self.solve(ball1, ball2) + ball1, ball2 = self.resolve_continually_touching(ball1, ball2) - return self.solve(ball1, ball2) + return ball1, ball2 @abstractmethod def solve(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: diff --git a/pooltool/physics/resolve/ball_cushion/core.py b/pooltool/physics/resolve/ball_cushion/core.py index 6c532df0..cd7ad45c 100644 --- a/pooltool/physics/resolve/ball_cushion/core.py +++ b/pooltool/physics/resolve/ball_cushion/core.py @@ -3,7 +3,6 @@ import numpy as np -import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.objects.ball.datatypes import Ball from pooltool.objects.table.components import ( @@ -57,7 +56,7 @@ def make_kiss(self, ball: Ball, cushion: LinearCushionSegment) -> Ball: This makes a correction such that if the ball is not a distance R from the cushion, the ball is moved along the normal such that it is. To avoid downstream float precision round-off error, a small epsilon of additional distance - (constants.EPS_SPACE) is put between them, ensuring the cushion and ball are + (``spacer``) is put between them, ensuring the cushion and ball are separated post-resolution. """ normal = cushion.get_normal_xy(ball.xyz) @@ -72,10 +71,10 @@ def make_kiss(self, ball: Ball, cushion: LinearCushionSegment) -> Ball: ) c[2] = ball.state.rvw[0, 2] + spacer = 1e-9 + # Move the ball to exactly meet the cushion - correction = ( - ball.params.R - ptmath.norm3d(ball.state.rvw[0] - c) + const.EPS_SPACE - ) + correction = ball.params.R - ptmath.norm3d(ball.state.rvw[0] - c) + spacer ball.state.rvw[0] -= correction * normal return ball @@ -107,7 +106,7 @@ def make_kiss(self, ball: Ball, cushion: CircularCushionSegment) -> Ball: This makes a correction such that if the ball is not a distance R from the cushion, the ball is moved along the normal such that it is. To avoid downstream float precision round-off error, a small epsilon of additional distance - (constants.EPS_SPACE) is put between them, ensuring the cushion and ball are + (``spacer``) is put between them, ensuring the cushion and ball are separated post-resolution. """ normal = cushion.get_normal_xy(ball.xyz) @@ -115,12 +114,14 @@ def make_kiss(self, ball: Ball, cushion: CircularCushionSegment) -> Ball: # orient the normal so it points away from playing surface normal = normal if np.dot(normal, ball.state.rvw[1]) > 0 else -normal + spacer = 1e-9 + c = np.array([cushion.center[0], cushion.center[1], ball.state.rvw[0, 2]]) correction = ( ball.params.R + cushion.radius - ptmath.norm3d(ball.state.rvw[0] - c) - - const.EPS_SPACE + - spacer ) ball.state.rvw[0] += correction * normal diff --git a/pooltool/physics/resolve/transition/__init__.py b/pooltool/physics/resolve/transition/__init__.py index eebb63c5..5c5eee19 100644 --- a/pooltool/physics/resolve/transition/__init__.py +++ b/pooltool/physics/resolve/transition/__init__.py @@ -15,6 +15,8 @@ from pooltool.objects.ball.datatypes import Ball from pooltool.physics.resolve.models import BallTransitionModel +_TOLERANCE = 1e-12 + class BallTransitionStrategy(Protocol): """Ball transition models must satisfy this protocol""" @@ -47,8 +49,8 @@ def resolve(self, ball: Ball, transition: EventType, inplace: bool = False) -> B # angular velocity components are nearly 0. Then set them to exactly 0. v = ball.state.rvw[1] w = ball.state.rvw[2] - assert (np.abs(v) < const.EPS_SPACE).all() - assert (np.abs(w[:2]) < const.EPS_SPACE).all() + assert (np.abs(v) < _TOLERANCE).all() + assert (np.abs(w[:2]) < _TOLERANCE).all() ball.state.rvw[1, :] = [0.0, 0.0, 0.0] ball.state.rvw[2, :2] = [0.0, 0.0] @@ -58,8 +60,8 @@ def resolve(self, ball: Ball, transition: EventType, inplace: bool = False) -> B # set them to exactly 0. v = ball.state.rvw[1] w = ball.state.rvw[2] - assert (np.abs(v) < const.EPS_SPACE).all() - assert (np.abs(w) < const.EPS_SPACE).all() + assert (np.abs(v) < _TOLERANCE).all() + assert (np.abs(w) < _TOLERANCE).all() ball.state.rvw[1, :] = [0.0, 0.0, 0.0] ball.state.rvw[2, :] = [0.0, 0.0, 0.0] diff --git a/pooltool/ptmath/roots/_quartic_numba.py b/pooltool/ptmath/roots/_quartic_numba.py index 52c28fcb..3bf60d6c 100644 --- a/pooltool/ptmath/roots/_quartic_numba.py +++ b/pooltool/ptmath/roots/_quartic_numba.py @@ -1,8 +1,14 @@ -"""1:1 exact translation of the "1010" quartic root-finding algorithm. +"""Translation of the "1010" quartic root-finding algorithm with modifications. The original implementation is written in C, and this module was written by Claude Code. -The 1:1 correspondence has been tested to floating point precision on a test of 100,000 -difficult to determine quartics. + +Modifications from the original algorithm: + - The threshold for falling back to an alternative factorization method has been + loosened by a safety factor (d2_safety_factor). The original threshold was too + strict for certain edge cases where the discriminant d2 is very small but nonzero, + causing the algorithm to produce duplicate roots instead of four distinct roots. + This is particularly relevant for Newton's cradle-like simulations where balls + move with very similar velocities. Solve speed: @@ -58,6 +64,7 @@ cubic_rescal_fact = 3.488062113727083e102 quart_rescal_fact = 7.156344627944542e76 macheps = 2.2204460492503131e-16 +d2_safety_factor = 100 @jit(nopython=True, cache=const.use_numba_cache) @@ -630,7 +637,8 @@ def solve(a: float, b: float, c: float, d: float, e: float) -> NDArray[np.comple realcase_0 = -1 if realcase_0 == -1 or ( - abs(d2) <= macheps * max(abs(2.0 * b_p / 3.0), abs(phi0), l1 * l1) + abs(d2) + <= d2_safety_factor * macheps * max(abs(2.0 * b_p / 3.0), abs(phi0), l1 * l1) ): d3 = d_p - l3 * l3 if realcase_0 == 1: diff --git a/pooltool/ptmath/roots/quadratic.py b/pooltool/ptmath/roots/quadratic.py index 2f8c885b..0c770832 100644 --- a/pooltool/ptmath/roots/quadratic.py +++ b/pooltool/ptmath/roots/quadratic.py @@ -1,7 +1,9 @@ +import cmath import math import numpy as np from numba import jit +from numpy.typing import NDArray import pooltool.constants as const @@ -19,3 +21,50 @@ def solve(a: float, b: float, c: float) -> tuple[float, float]: u1 = (-bp - delta**0.5) / a u2 = -u1 - b / a return u1, u2 + + +# TODO: In the branch `3d`, which will eventually be merged into main, this function has +# replaced `solve`. +@jit(nopython=True, cache=const.use_numba_cache) +def solve_complex(a: float, b: float, c: float) -> NDArray[np.complex128]: + _a = complex(a) + _b = complex(b) + _c = complex(c) + + roots = np.full(2, np.nan, dtype=np.complex128) + + if abs(_a) != 0: + # Quadratic case + d = _b * _b - 4 * _a * _c + sqrt_d = cmath.sqrt(d) + + # Sign trick to reduce catastrophic cancellation + sign_b = 1.0 if _b.real >= 0 else -1.0 + + r1_num = -_b - sign_b * sqrt_d + r1_den = 2 * _a + + # Fallback if numerator is tiny + if abs(r1_num) < 1e-14 * abs(r1_den): + r1_num = -_b + sign_b * sqrt_d + + r1 = r1_num / r1_den + + # Use product identity for x2 + if abs(r1) < 1e-14: + r2 = (-_b + sqrt_d) / (2 * _a) + else: + r2 = (_c / _a) / r1 + + roots[0] = r1 + roots[1] = r2 + return roots + + if abs(_b) != 0: + # Linear case + r1 = -_c / _b + roots[0] = r1 + return roots + + # Equation is just c=0. Either zero or infinite solutions. Returns nans + return roots diff --git a/pooltool/ptmath/utils.py b/pooltool/ptmath/utils.py index 4cb1f746..253288d5 100644 --- a/pooltool/ptmath/utils.py +++ b/pooltool/ptmath/utils.py @@ -288,7 +288,7 @@ def point_on_line_closest_to_point( @jit(nopython=True, cache=const.use_numba_cache) def squared_norm3d(vec: NDArray[np.float64]) -> float: """Calculate the squared norm of a 3D vector""" - return vec[0] ** 2 + vec[1] ** 2 + vec[2] ** 2 + return vec[0] * vec[0] + vec[1] * vec[1] + vec[2] * vec[2] @jit(nopython=True, cache=const.use_numba_cache) @@ -312,7 +312,7 @@ def norm3d(vec: NDArray[np.float64]) -> float: @jit(nopython=True, cache=const.use_numba_cache) def squared_norm2d(vec: NDArray[np.float64]) -> float: """Calculate the squared norm of a 2D vector""" - return vec[0] ** 2 + vec[1] ** 2 + return vec[0] * vec[0] + vec[1] * vec[1] @jit(nopython=True, cache=const.use_numba_cache) @@ -409,9 +409,13 @@ def get_ball_energy(rvw: NDArray[np.float64], R: float, m: float) -> float: def is_overlapping( - rvw1: NDArray[np.float64], rvw2: NDArray[np.float64], R1: float, R2: float + rvw1: NDArray[np.float64], + rvw2: NDArray[np.float64], + R1: float, + R2: float, + min_spacer: float = 0.0, ) -> bool: - return norm3d(rvw1[0] - rvw2[0]) < (R1 + R2) + return norm3d(rvw1[0] - rvw2[0]) < (R1 + R2 + min_spacer) @jit(nopython=True, cache=const.use_numba_cache) diff --git a/sandbox/newtons_cradle.py b/sandbox/newtons_cradle.py new file mode 100644 index 00000000..6f98cb6f --- /dev/null +++ b/sandbox/newtons_cradle.py @@ -0,0 +1,64 @@ +import argparse + +import numpy as np + +import pooltool as pt +from pooltool.layouts import BallPos, Jump, ball_cluster_blueprint, generate_layout + + +def main(): + parser = argparse.ArgumentParser( + description="Simulate an imperfect Newton's cradle. WARNING: if angle-variation " + "is low and n-balls is high, the event count will skyrocket.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--n-balls", type=int, default=4, help="Number of balls") + parser.add_argument( + "--angle-variation", type=float, default=0.3, help="Max angle offset (degrees)" + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + args = parser.parse_args() + + rng = np.random.default_rng(args.seed) + + table = pt.Table.default() + ball_params = pt.BallParams.default() + + jump_sequence = [ + ( + Jump.ANGLE(90 + rng.uniform(-args.angle_variation, args.angle_variation)), + {str(i + 2)}, + ) + for i in range(args.n_balls - 1) + ] + + blueprint = ball_cluster_blueprint( + seed=BallPos([], (0.5, 0.4), {"1"}), + jump_sequence=jump_sequence, + ) + + cue = BallPos([], (0.5, 0.1), {"cue"}) + blueprint.append(cue) + + balls = generate_layout( + blueprint, + table, + ballset=pt.objects.BallSet("pooltool_pocket"), + ball_params=ball_params, + spacing_factor=0, + ) + + system = pt.System( + balls=balls, + table=table, + cue=pt.Cue.default(), + ) + + system.strike(V0=2, phi=pt.aim.at_ball(system, "1", cut=0)) + pt.simulate(system, inplace=True) + print(len(system.events)) + pt.show(system) + + +if __name__ == "__main__": + main() diff --git a/tests/evolution/event_based/test_introspection.py b/tests/evolution/event_based/test_introspection.py index 216924e0..186aa3b0 100644 --- a/tests/evolution/event_based/test_introspection.py +++ b/tests/evolution/event_based/test_introspection.py @@ -35,8 +35,8 @@ def test_selected_event_in_all_possible_events(): all_events = snapshot.get_prospective_events() first_event = all_events[0] - assert snapshot.selected_event.event_type == first_event.event_type - assert snapshot.selected_event.time == first_event.time + assert snapshot.next_event.event_type == first_event.event_type + assert snapshot.next_event.time == first_event.time def test_pre_evolve_equals_snapshot_system(): @@ -59,7 +59,7 @@ def test_post_evolve_advances_time(): for step in range(len(seq)): snapshot = seq[step] - event = snapshot.selected_event + event = snapshot.next_event post_evolve = snapshot.post_evolve_system(event) assert post_evolve.t == event.time @@ -74,9 +74,7 @@ def test_post_resolve_of_n_equals_pre_evolve_of_n_plus_1(): current_snapshot = seq[step] next_snapshot = seq[step + 1] - post_resolve = current_snapshot.post_resolve_system( - current_snapshot.selected_event - ) + post_resolve = current_snapshot.post_resolve_system(current_snapshot.next_event) pre_evolve_next = next_snapshot.pre_evolve_system() assert post_resolve == pre_evolve_next @@ -89,7 +87,7 @@ def test_system_state_progression(): for step in range(len(seq)): snapshot = seq[step] - event = snapshot.selected_event + event = snapshot.next_event pre_evolve = snapshot.pre_evolve_system() post_evolve = snapshot.post_evolve_system(event) diff --git a/tests/evolution/event_based/test_simulate.py b/tests/evolution/event_based/test_simulate.py index d54fa2d5..16aa3ebb 100644 --- a/tests/evolution/event_based/test_simulate.py +++ b/tests/evolution/event_based/test_simulate.py @@ -4,6 +4,7 @@ import pooltool.constants as const import pooltool.ptmath as ptmath +from pooltool import aim, events from pooltool.events import EventType, ball_ball_collision, ball_pocket_collision from pooltool.evolution.event_based.cache import CollisionCache from pooltool.evolution.event_based.simulate import ( @@ -14,6 +15,8 @@ ) from pooltool.evolution.event_based.solve import ball_ball_collision_time from pooltool.objects import Ball, BilliardTableSpecs, Cue, Table +from pooltool.objects.ball.params import BallParams +from pooltool.objects.ball.sets import BallSet from pooltool.ptmath.roots import quadratic from pooltool.system import System from tests.evolution.event_based.test_data import TEST_DIR @@ -403,16 +406,15 @@ def true_time_to_collision(eps, V0, mu_r, g): assert diff < 10e-12 # Less than 10 femptosecond difference -def test_no_ball_ball_collisions_for_intersecting_balls(): - """Two already intersecting balls don't collide +def test_ball_ball_collision_for_intersecting_balls(): + """Two already intersecting balls collide. - In this instance, no further collision is detected because the balls are already - intersecting. Otherwise perpetual internal collisions occur, keeping the two balls - locked. + Previously, intersecting balls were prevented from colliding to avoid perpetual + internal collisions. Now, with the improved make_kiss implementation, intersecting + balls are properly separated and collide normally. - This test doesn't make sure that balls don't intersect, it tests the safeguard that - prevents already intersecting balls from colliding with their internal walls, which - keeps them intersected like links in a chain. + This test verifies that intersecting balls are detected as a collision at time == + shot.t , - ~ , , - ~ , , ' ' ,, ' ' , @@ -446,8 +448,10 @@ def test_no_ball_ball_collisions_for_intersecting_balls(): # The cue is truly rolling _assert_rolling(system.balls["cue"].state.rvw, system.balls["cue"].params.R) - assert get_next_event(system).event_type != EventType.BALL_BALL - assert get_next_ball_ball_collision(system, CollisionCache()).time == np.inf + assert get_next_event(system).event_type == EventType.BALL_BALL + collision_event = get_next_ball_ball_collision(system, CollisionCache()) + assert collision_event.time != np.inf + assert collision_event.time == 0 def test_ball_history_immutability(): @@ -539,3 +543,140 @@ def test_stick_ball_event_detection(): assert simulated.events[1].event_type == EventType.STICK_BALL assert simulated.events[1].time == 0 assert simulated.events[2].time > 0 + + +def test_newtons_cradle_backspin(): + """Test Newtons's cradle when incoming ball has backspin. + + This is a easier simulation scenario than test_newtons_cradle_rolling_spin because + the incoming ball's spin, after the first collision, pulls it away from the line of + centers, avoiding a chain of followup collisions that ultimately push the whole line + of balls forwards. + """ + # Create a newton's cradle system, where balls 1, 2, and 3 are in a line. the cue + # ball is placed on the same line of centers some distance away and has momentum + # towards the 1 ball. + table = Table.default() + ball_radius = BallParams.default().R + + balls = [ + Ball.create( + str(i + 1), + xy=(0.5 * table.w, 0.4 * table.l + 2 * i * ball_radius), + ballset=BallSet("pooltool_pocket"), + ) + for i in range(3) + ] + + balls.append( + Ball.create( + "cue", + xy=(0.5 * table.w, 0.1 * table.l), + ballset=BallSet("pooltool_pocket"), + ) + ) + + system = System( + balls=balls, + table=table, + cue=Cue.default(), + ) + + system.strike(V0=2, b=-0.5, phi=aim.at_ball(system, "1", cut=0)) + system = simulate(system, max_events=10) + + # Define a 1-millisecond window that starts from the first collision. + first_collision = events.filter_type(system.events, EventType.BALL_BALL)[0] + start_time = first_collision.time - 1e-9 + end_time = first_collision.time + 1e-3 + + collision_chain = events.filter_events( + system.events, + events.by_time(start_time, after=True), + events.by_time(end_time, after=False), + events.by_type(EventType.BALL_BALL), + ) + + # The last ball should be involved in exactly one collision, since a near-elastic + # collision should effectively halt the second-last ball while sending the last ball + # flying away, precluding followup collisions. + assert len(events.filter_ball(collision_chain, "3")) == 1 + + # We don't assert how the second wave collision waves play out, but we know the + # collision sequence up until the last ball is hit: a chain of events where momentum + # is transfered sequentially through the line of balls. + expected_collision_order = [{"cue", "1"}, {"1", "2"}, {"2", "3"}] + + assert len(collision_chain) >= len(expected_collision_order) + + for idx, expected in enumerate(expected_collision_order): + actual = set(collision_chain[idx].ids) + assert expected == actual + + +def test_newtons_cradle_rolling_spin(): + """Test Newtons's cradle when incoming ball has rolling spin. + + This is a harder simulation scenario than test_newtons_cradle_backspin because the + incoming ball's spin, after the first collision, triggers many followup collisions + that ultimately push the who line of balls forwards. + """ + # Create a newton's cradle system, where balls 1, 2, and 3 are in a line. the cue + # ball is placed on the same line of centers some distance away and has momentum + # towards the 1 ball. + table = Table.default() + ball_radius = BallParams.default().R + + balls = [ + Ball.create( + str(i + 1), + xy=(0.5 * table.w, 0.4 * table.l + 2 * i * ball_radius), + ballset=BallSet("pooltool_pocket"), + ) + for i in range(3) + ] + + balls.append( + Ball.create( + "cue", + xy=(0.5 * table.w, 0.1 * table.l), + ballset=BallSet("pooltool_pocket"), + ) + ) + + system = System( + balls=balls, + table=table, + cue=Cue.default(), + ) + + system.strike(V0=2, phi=aim.at_ball(system, "1", cut=0)) + simulate(system, inplace=True, max_events=10) + + # Define a 1-millisecond window that starts from the first collision. + first_collision = events.filter_type(system.events, EventType.BALL_BALL)[0] + start_time = first_collision.time - 1e-9 + end_time = first_collision.time + 1e-3 + + collision_chain = events.filter_events( + system.events, + events.by_time(start_time, after=True), + events.by_time(end_time, after=False), + events.by_type(EventType.BALL_BALL), + ) + + # The last ball should be involved in exactly one collision, since a near-elastic + # collision should effectively halt the second-last ball while sending the last ball + # flying away, precluding followup collisions. + assert len(events.filter_ball(collision_chain, "3")) == 1 + + # We don't assert how the second wave collision waves play out, but we know the + # collision sequence up until the last ball is hit: a chain of events where momentum + # is transfered sequentially through the line of balls. + expected_collision_order = [{"cue", "1"}, {"1", "2"}, {"2", "3"}] + + assert len(collision_chain) >= len(expected_collision_order) + + for idx, expected in enumerate(expected_collision_order): + actual = set(collision_chain[idx].ids) + assert expected == actual