diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 0000000..10d64e3 --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,33 @@ +name: Unit Tests + +on: + push: + branches: + - main + pull_request: + +jobs: + tests: + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + cache-dependency-path: | + pyproject.toml + requirements.txt + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install . + python -m pip install pytest + + - name: Run unit tests + run: pytest tests/unit diff --git a/README.md b/README.md index 1c365c8..e0e99e7 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,13 @@ A dynamical system is in a _viable_ state if there exist control inputs that all ## Installation We recommend using a virtual environment. Note, `GPy` is only required for the safe learning examples, and can be safely removed from the requirements. Install from your terminal with -conda (recommended): +[uv](https://docs.astral.sh/uv/getting-started/installation/) (recommended): +`uv venv .venv` +`source .venv/bin/activate` +`uv pip install -r requirements.txt` +`uv pip install -e .` + +conda: `conda create -n vibly python=3.9` `conda activate vibly` then follow the same instructions as for pip: @@ -45,7 +51,7 @@ from measure import active_sampling Examples are shown in the `demos` folder. We recommend starting with `slip_demo.py`, and then `computeQ_slip.py`. The `viability` package contains: - `compute_Q_map`: a utility to compute a gridded transition map for N-dimensional systems. Note, this can be computationally intensives (it is essentially brute-forcing an N-dimensional problem). It typically works reasonably well for up to ~4 dimensions. -- `parcompute_Q_map`: same as above, but parallelized. 
You typically want to use this, unless running a debugger. +- `compute_Q_map(..., parallel=True)`: parallelised version via multiprocessing; typically preferable unless debugging. - `compute_QV`: computes the viability kernel and viable set to within conservative discrete approximation, using the grid generated by `compute_Q_map`. - `get_feasibility_mask`: this can be used to exclude parts of the grid which are infeasible (i.e. are not physically meaningful) - `project_Q2S`: Apply an operator (default is an orthogonal projection) from state-action space to state space. Used to compute measures. diff --git a/demos/TAC11/computeQ_satellite.py b/demos/TAC11/computeQ_satellite.py index dc7241c..0b84760 100644 --- a/demos/TAC11/computeQ_satellite.py +++ b/demos/TAC11/computeQ_satellite.py @@ -35,9 +35,18 @@ grids = {"states": s_grid, "actions": a_grid} tictoc.tic() - Q_map, Q_F, Q_coords = vibly.parcompute_Q_mapC( - grids, p_map, verbose=1, check_grid=False, keep_coords=True + result = vibly.compute_Q_map( + grids, + p_map, + verbose=1, + check_grid=False, + keep_coords=True, + parallel=True, + bin_mode="nearest", ) + Q_map = result.q_map + Q_F = result.q_fail + Q_coords = result.q_reached # Q_map, Q_F, Q_on_grid = vibly.compute_Q_map(grids, p_map, check_grid=True) # * compute_QV computes the viable set and viability kernel diff --git a/demos/computeQ_daslip.py b/demos/computeQ_daslip.py index 4746875..ae50ff2 100644 --- a/demos/computeQ_daslip.py +++ b/demos/computeQ_daslip.py @@ -71,7 +71,9 @@ a_grid = (np.linspace(20 / 180 * np.pi, 60 / 180 * np.pi, 25),) grids = {"states": s_grid, "actions": a_grid} - Q_map, Q_F = vibly.parcompute_Q_map(grids, p_map, verbose=2) + result = vibly.compute_Q_map(grids, p_map, verbose=2, parallel=True) + Q_map = result.q_map + Q_F = result.q_fail Q_V, S_V = vibly.compute_QV(Q_map, grids) S_M = vibly.project_Q2S(Q_V, grids, proj_opt=np.mean) Q_M = vibly.map_S2Q(Q_map, S_M, s_grid=s_grid, Q_V=Q_V) diff --git a/demos/computeQ_hovership.py 
b/demos/computeQ_hovership.py index d9c21fb..bec7418 100644 --- a/demos/computeQ_hovership.py +++ b/demos/computeQ_hovership.py @@ -39,15 +39,16 @@ # * compute_Q_map computes a gridded transition map, `Q_map`, which is used # * a look-up table for computing viable sets. - # * Switch to `parcompute_Q_map` to use parallelized version - # * (requires multiprocessing module) + # * Enable `parallel=True` to use the multiprocessing version # * Q_F is a grid marking all failing state-action pairs # * Q_on_grid is a helper grid, which marks if a state has not moved at all # * this is used to catch corner cases, and is not important for most # * systems with interesting dynamics # * setting `check_grid` to False will omit Q_on_grid - Q_map, Q_F, Q_on_grid = vibly.parcompute_Q_map(grids, p_map, check_grid=True) - # Q_map, Q_F, Q_on_grid = vibly.compute_Q_map(grids, p_map, check_grid=True) + result = vibly.compute_Q_map(grids, p_map, check_grid=True, parallel=True) + Q_map = result.q_map + Q_F = result.q_fail + Q_on_grid = result.q_on_grid # * compute_QV computes the viable set and viability kernel Q_V, S_V = vibly.compute_QV(Q_map, grids, ~Q_F, Q_on_grid=Q_on_grid) diff --git a/demos/computeQ_lip.py b/demos/computeQ_lip.py index 8774a2f..33c3bc6 100644 --- a/demos/computeQ_lip.py +++ b/demos/computeQ_lip.py @@ -45,7 +45,9 @@ grids = {"states": s_grid, "actions": a_grid} tictoc.tic() - Q_map, Q_F = vibly.parcompute_Q_map(grids, p_map, verbose=1) + result = vibly.compute_Q_map(grids, p_map, verbose=1, parallel=True) + Q_map = result.q_map + Q_F = result.q_fail time_elapsed = tictoc.toc() print("time elapsed: " + str(time_elapsed / 60.0)) # * compute_QV computes the viable set and viability kernel diff --git a/demos/computeQ_nslip.py b/demos/computeQ_nslip.py new file mode 100644 index 0000000..40814b0 --- /dev/null +++ b/demos/computeQ_nslip.py @@ -0,0 +1,67 @@ +import numpy as np +import matplotlib.pyplot as plt +from models import nslip +import viability as vibly + + +if 
__name__ == "__main__": + p = { + "mass": 80.0, + "stiffness": 705.0, + "resting_angle": 17 / 18 * np.pi, + "gravity": 9.81, + "angle_of_attack": 1 / 5 * np.pi, + "upper_leg": 0.5, + "lower_leg": 0.5, + } + + x0 = np.array([0.0, 0.85, 5.5, 0.0, 0.0, 0.0, 0.0]) + x0 = nslip.reset_leg(x0, p) + p["x0"] = x0 + p["total_energy"] = nslip.compute_total_energy(x0, p) + + p_map = nslip.p_map + p_map.p = p + p_map.x = x0 + p_map.sa2xp = nslip.sa2xp + p_map.xp2s = nslip.xp2s + + s_grid = np.linspace(0.1, 1.0, 61) + s_grid = (s_grid[:-1],) + a_grid = (np.linspace(-10 / 180 * np.pi, 70 / 180 * np.pi, 61),) + grids = {"states": s_grid, "actions": a_grid} + + result = vibly.compute_Q_map(grids, p_map, parallel=True) + Q_map = result.q_map + Q_F = result.q_fail + Q_V, S_V = vibly.compute_QV(Q_map, grids) + S_M = vibly.project_Q2S(Q_V, grids, proj_opt=np.mean) + Q_M = vibly.map_S2Q(Q_map, S_M, s_grid, Q_V=Q_V) + + import pickle + import os + + filename = "nslip_map.pickle" + + if os.path.exists("data"): + path_to_file = "data/dynamics/" + else: + path_to_file = "../data/dynamics/" + if not os.path.exists(path_to_file): + os.makedirs(path_to_file) + + data2save = { + "grids": grids, + "Q_map": Q_map, + "Q_F": Q_F, + "Q_V": Q_V, + "Q_M": Q_M, + "S_M": S_M, + "p": p, + "x0": x0, + } + with open(path_to_file + filename, "wb") as outfile: + pickle.dump(data2save, outfile) + + plt.imshow(Q_map, origin="lower") + plt.show() diff --git a/demos/computeQ_slip.py b/demos/computeQ_slip.py index 7b4d06c..3a413e5 100644 --- a/demos/computeQ_slip.py +++ b/demos/computeQ_slip.py @@ -28,7 +28,9 @@ a_grid = (np.linspace(-10 / 180 * np.pi, 70 / 180 * np.pi, 161),) grids = {"states": s_grid, "actions": a_grid} # Q_map, Q_F = vibly.compute_Q_map(grids, p_map) - Q_map, Q_F = vibly.parcompute_Q_map(grids, p_map) + result = vibly.compute_Q_map(grids, p_map, parallel=True) + Q_map = result.q_map + Q_F = result.q_fail Q_V, S_V = vibly.compute_QV(Q_map, grids) S_M = vibly.project_Q2S(Q_V, grids, 
proj_opt=np.mean) Q_M = vibly.map_S2Q(Q_map, S_M, s_grid, Q_V=Q_V) diff --git a/demos/computeQ_spaceship4.py b/demos/computeQ_spaceship4.py index 95d4b95..80fa5d5 100644 --- a/demos/computeQ_spaceship4.py +++ b/demos/computeQ_spaceship4.py @@ -48,9 +48,16 @@ # * Q_on_grid is a helper grid, which marks if a state has not moved # * this is used to catch corner cases, and is not important for most systems # * setting `check_grid` to False will omit Q_on_grid - Q_map, Q_F, Q_on_grid = vibly.parcompute_Q_map( - grids, p_map, check_grid=True, verbose=1 + result = vibly.compute_Q_map( + grids, + p_map, + check_grid=True, + verbose=1, + parallel=True, ) + Q_map = result.q_map + Q_F = result.q_fail + Q_on_grid = result.q_on_grid # * compute_QV computes the viable set and viability kernel Q_V, S_V = vibly.compute_QV(Q_map, grids, ~Q_F, Q_on_grid=Q_on_grid) diff --git a/demos/damping_study/compute_measure_damping.py b/demos/damping_study/compute_measure_damping.py index f3b4b6f..a1e559a 100644 --- a/demos/damping_study/compute_measure_damping.py +++ b/demos/damping_study/compute_measure_damping.py @@ -75,7 +75,9 @@ def compute_viability(x0, p, name, visualise=False): grids = {"states": s_grid, "actions": a_grid} # * compute transition matrix and boolean matrix of failures - Q_map, Q_F = vibly.parcompute_Q_map(grids, p_map, verbose=1) + result = vibly.compute_Q_map(grids, p_map, verbose=1, parallel=True) + Q_map = result.q_map + Q_F = result.q_fail # * compute viable sets Q_V, S_V = vibly.compute_QV(Q_map, grids) diff --git a/models/acrobot.py b/models/acrobot.py deleted file mode 100644 index 344ca00..0000000 --- a/models/acrobot.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np - -""" -space attempting to reconnoitre the surface of a planet. -Must ensure not to go to the dark side of the planet. -x_{k+1} = map(x_k, p) -x: (x1, x2) -x1: altitude -x2: longitude -p: dict of parameters. For convenience, actions are also stored here. 
-""" - -# map: x_k+1, failed = map - - -def wind(x, p): - # return -0.5*x[1] - # stretch = (p['x0_upper_bound']-p['x0_lower_bound'])*2*np.pi - return p["wind"] * np.sin(x[0] * np.pi) - - -def gravity(x, p): - return np.max([0, x[0]]) * p["gravity"] - - -def mass_matrix(x, p): - # unpack - m1 = p["m1"] - m2 = p["m2"] - l1 = p["l1"] - l2 = p["l2"] - - a = m1 * l1**2 + m2 * l1**2 - b = m2 * l2**2 - c = m2 * l1 * l2 - - return np.array( - [ - [a + b + 2 * c * np.cos(x[1]), b + c * np.cos(x[1])], - [b + c * np.cos(x[1]), b], - ] - ) - - -def coriolis(x, p): - m2 = p["m2"] - l1 = p["l1"] - l2 = p["l2"] - - c = m2 * l1 * l2 - - return np.array( - [ - [-c * np.sin(x[1]) * x[3], -c * np.sin(x[0] + x[1])], - [c * np.sin(x[1]) * x[2], 0], - ] - ) - - -def gravitational(x, p): - m1 = p["m1"] - m2 = p["m2"] - l1 = p["l1"] - l2 = p["l2"] - g = p["g"] - - d = g * m1 * l1 + g * m2 * l1 - e = g * m2 * l2 - - return np.array( - [-d * np.sin(x[0]) - e * np.sin(x[0] + x[1]), e * np.sin(x[0] + x[1])] - ) - - -def p_map(x, p): - """ - Dynamics function of your system - Note that the control input is included in the parameter, - and needs to be unpacked. - """ - if check_failure(x, p): - return x, True - - M = mass_matrix(x, p) - C = coriolis(x, p) - G = gravitational(x, p) - - tau = np.array([0, p["torque"]]) - - # TODO check dimensions - x += p["t_step"] * np.stack(x[0:2], np.linalg.solve(M, tau - C * x[2:] - G)) - - return x, check_failure(x, p) - - -def check_failure(x, p): - """ - Check if a state is in the failure set. 
- """ - elbow_height = p["l1"] * np.cos(x[0]) - end_eff_height = p["l1"] * np.cos(x[0]) + p["l2"] * np.cos(x[1] - x[0]) - if elbow_height <= 0.0: - return True - elif end_eff_height <= 0.0: - return True - else: - return False - - -# Viability functions -def sa2xp(state_action, x, p): - x = np.atleast_1d(state_action[: p["n_states"]]) - p["torque"] = np.atleast_1d(state_action[p["n_states"] :]) - p["torque"] = np.min([p["u_upper_bound"], p["torque"]]) # bound torque - p["torque"] = np.max([p["u_lower_bound"], p["torque"]]) - return x, p - - -def xp2s(x, p): - return x diff --git a/models/ardyn.py b/models/ardyn.py deleted file mode 100644 index 86337ac..0000000 --- a/models/ardyn.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -""" -Arbitrary dynamics -x_{k+1} = map(x_k, p) -p: dict of parameters. For convenience, actions are also stored here. -""" - -# map: x_k+1, failed = map - - -def p_map(x, p): - """ - Dynamics function of your system - Note that the control input is included in the parameter, - and needs to be unpacked. - """ - if check_failure(x, p): - return x, True - - x += np.minimum(1, np.linalg.norm(x)) * (p["nonlinear"](x, p)) + p["actions"] - - return x, check_failure(x, p) - - -def check_failure(x, p): - """ - Check if a state is in the failure set. Pass in a tuple of indices for which - failure conditions to check. Currently: 0 for falling, 1 for direction rev. - """ - if np.linalg.norm(x) <= p["fail_bound"]: - return False - else: - return True - - -# Viability functions - - -def sa2xp(state_action, x, p): - x = np.atleast_2d(state_action[: p["n_states"]]) - p["actions"] = np.atleast_2d(state_action[p["n_states"] :]) - return x, p - - -def xp2s(x, p): - return x diff --git a/models/nslip.py b/models/nslip.py index 838b7f2..3400d4c 100644 --- a/models/nslip.py +++ b/models/nslip.py @@ -108,15 +108,10 @@ def stance_dynamics2(t, x): def stance_dynamics(t, x): # since legs are massless, the orientation of the knee doesn't matter. 
alpha = np.arctan2(x[1] - x[5], x[0] - x[4]) - np.pi / 2.0 - leg_length = np.hypot(x[0] - x[4], x[1] - x[5]) - # if np.greater_equal((UPPER_LEG**2+LOWER_LEG**2 - leg_length**2) - # /(2*UPPER_LEG*LOWER_LEG), 1): - # print("warning") + leg_length = min(UPPER_LEG + LOWER_LEG, np.hypot(x[0] - x[4], x[1] - x[5])) beta = np.arccos( (UPPER_LEG**2 + LOWER_LEG**2 - leg_length**2) / (2 * UPPER_LEG * LOWER_LEG) ) - # if np.isnan(beta): #TODO test for minimum value... - # print("HELLO!") # sinbeta = max(np.sin(beta), 1e-5) tau = STIFFNESS * (RESTING_ANGLE - beta) leg_force = leg_length / (UPPER_LEG * LOWER_LEG) * tau / np.sin(beta) @@ -321,7 +316,7 @@ def s2x(x, p, s): # check that we are at apex assert np.isclose(x[3], 0), "state x: " + str(x) + " and e: " + str(s) - x_new = x + x_new = np.array(x, copy=True) x_new[1] = p["total_energy"] * s / p["mass"] / p["gravity"] x_new[2] = np.sqrt(p["total_energy"] * (1 - s) / p["mass"] * 2) x_new[3] = 0.0 # shouldn't be necessary, but avoids errors accumulating @@ -336,4 +331,6 @@ def sa2xp(state_action, p): p_new = p.copy() p_new["angle_of_attack"] = state_action[1] x = s2x(p_new["x0"], p_new, state_action[0].copy()) - return x, p_new + if isinstance(p_new["x0"], np.ndarray): + p_new["x0"] = p_new["x0"].copy() + return x.copy(), p_new diff --git a/models/slip.py b/models/slip.py index 287c9e8..dbfa836 100644 --- a/models/slip.py +++ b/models/slip.py @@ -1,9 +1,196 @@ import numpy as np import scipy.integrate as integrate -# from numba import jit +from numba import njit -def feasible(x, p): +@njit(cache=True) +def _compute_flight_dynamics(x, gravity): + return np.array( + [x[2], x[3], 0.0, -gravity, x[2], x[3], 0.0], + dtype=np.float64, + ) + + +@njit(cache=True) +def _compute_stance_dynamics( + x, + gravity, + mass, + stiffness, + resting_length, + leg_length_offset, +): + x_rel = x[0] - x[4] + y_rel = x[1] - x[5] + alpha = np.arctan2(y_rel, x_rel) - np.pi / 2.0 + spring_length = np.hypot(x_rel, y_rel) - leg_length_offset + leg_force 
= stiffness / mass * (resting_length - spring_length) + xdotdot = -leg_force * np.sin(alpha) + ydotdot = leg_force * np.cos(alpha) - gravity + result = np.empty(7, dtype=np.float64) + result[0] = x[2] + result[1] = x[3] + result[2] = xdotdot + result[3] = ydotdot + result[4] = 0.0 + result[5] = 0.0 + result[6] = 0.0 + return result + + +@njit(cache=True) +def _evaluate_fall_event(x): + return x[1] + + +@njit(cache=True) +def _evaluate_touchdown_event(x): + return x[5] - x[6] + + +@njit(cache=True) +def _evaluate_liftoff_event(x, resting_length, leg_length_offset): + x_rel = x[0] - x[4] + y_rel = x[1] - x[5] + spring_length = np.hypot(x_rel, y_rel) - leg_length_offset + return spring_length - resting_length + + +@njit(cache=True) +def _evaluate_reversal_event(x): + return x[2] + 1e-5 + + +@njit(cache=True) +def _evaluate_apex_event(x): + return x[3] + + +class _SimpleSolution: + def __init__(self, t, y, t_events, y_events, message="Analytic flight segment"): + self.t = t + self.y = y + self.t_events = t_events + self.y_events = y_events + self.status = 0 + self.message = message + self.success = True + + +def _solve_ballistic_quadratic(offset, velocity, gravity): + """ + Solve offset + velocity * dt - 0.5 * gravity * dt^2 = 0 for the smallest non-negative dt. + Returns None if there is no non-negative real solution within numerical tolerance. 
+ """ + a = -0.5 * gravity + b = velocity + c = offset + discriminant = b * b - 4 * a * c + if discriminant < 0: + if discriminant > -1e-12: + discriminant = 0.0 + else: + return None + sqrt_disc = np.sqrt(discriminant) + denom = 2 * a + roots = [] + for root in ((-b + sqrt_disc) / denom, (-b - sqrt_disc) / denom): + if root >= 0: + roots.append(max(0.0, root)) + if not roots: + return None + return min(roots) + + +def _compute_apex_time(velocity, gravity): + if velocity < 0: + return None + if np.isclose(velocity, 0.0): + return 0.0 + return velocity / gravity + + +def _compute_flight_states(x0, vx0, vy0, gravity, deltas): + x_positions = x0[0] + vx0 * deltas + y_positions = x0[1] + vy0 * deltas - 0.5 * gravity * deltas**2 + vx = np.full_like(deltas, vx0, dtype=float) + vy = vy0 - gravity * deltas + foot_x = x0[4] + vx0 * deltas + foot_y = x0[5] + vy0 * deltas - 0.5 * gravity * deltas**2 + ground = np.full_like(deltas, x0[-1], dtype=float) + return np.vstack((x_positions, y_positions, vx, vy, foot_x, foot_y, ground)) + + +def _integrate_flight_analytically(x0, p, t_start, max_time, event_types, max_step): + gravity = p["gravity"] + x0 = np.asarray(x0, dtype=float) + vx0 = x0[2] + vy0 = x0[3] + if max_step <= 0: + raise ValueError("max_step must be positive for analytic flight integration.") + + event_times = [None] * len(event_types) + for idx, event_name in enumerate(event_types): + if event_name == "fall": + event_times[idx] = _solve_ballistic_quadratic(x0[1], vy0, gravity) + elif event_name == "touchdown": + foot_offset = x0[5] - x0[-1] + event_times[idx] = _solve_ballistic_quadratic(foot_offset, vy0, gravity) + elif event_name == "apex": + event_times[idx] = _compute_apex_time(vy0, gravity) + + event_hit_idx = None + dt_event = None + for idx, candidate in enumerate(event_times): + if candidate is None: + continue + if candidate > max_time: + continue + if dt_event is None or candidate < dt_event: + dt_event = candidate + event_hit_idx = idx + + if dt_event is 
None: + duration = max_time + event_hit_idx = None + else: + duration = dt_event + + if duration <= 0: + t_samples = np.array([t_start], dtype=float) + y_samples = x0.reshape(-1, 1) + else: + steps = max(int(np.ceil(duration / max_step)), 1) + deltas = np.linspace(0.0, duration, steps + 1) + t_samples = t_start + deltas + y_samples = _compute_flight_states(x0, vx0, vy0, gravity, deltas) + + t_events = [np.array([], dtype=float) for _ in event_types] + y_events = [np.zeros((0, x0.size)) for _ in event_types] + if event_hit_idx is not None: + event_time = t_start + duration + t_events[event_hit_idx] = np.array([event_time], dtype=float) + y_events[event_hit_idx] = y_samples[:, -1].reshape(1, -1) + + return _SimpleSolution( + t=t_samples, + y=y_samples, + t_events=t_events, + y_events=y_events, + ) + + +def _integrate_flight_numerically(dynamics, state, t_start, max_time, events, max_step): + return integrate.solve_ivp( + dynamics, + t_span=[t_start, t_start + max_time], + y0=state, + events=events, + max_step=max_step, + ) + + +def is_feasible(x, p): """ check if state is at all feasible (body/foot underground) returns a boolean @@ -14,38 +201,15 @@ def feasible(x, p): def p_map(x, p): - """ - Wrapper function for step function, returning only x_next, and -1 if failed - Essentially, the Poincare map. 
- """ - if type(p) is dict: - if not feasible(x, p): - return x, True # return failed if foot starts underground - sol = step(x, p) - # if len(sol.t_events) < 7: - # # print(len(sol.t_events)) - # return sol.y[:, -1], True - return sol.y[:, -1], check_failure(sol.y[:, -1]) - elif type(p) is tuple: - vector_of_x = np.zeros(x.shape) # initialize result array - vector_of_fail = np.zeros(x.shape[1]) - # TODO: for shorthand, allow just a single tuple to be passed in - # this can be done easily with itertools - for idx, p0 in enumerate(p): - if not feasible(x, p): - vector_of_x[:, idx] = x[:, idx] - vector_of_fail[idx] = True - else: - sol = step(x[:, idx], p0) # p0 = p[idx] - vector_of_x[:, idx] = sol.y[:, -1] - vector_of_fail[idx] = check_failure(sol.y[:, -1]) - return vector_of_x, vector_of_fail - else: - print("WARNING: I got a parameter type that I don't understand.") - return (x, True) + x_state = np.asarray(x, dtype=float) + if not is_feasible(x_state, p): + return x_state, True + sol = step(x_state, p) + x_next = sol.y[:, -1] + return x_next, check_failure(x_next) -def step(x0, p, prev_sol=None): +def step(x0, p, prev_sol=None, flight_mode=None): """ Take one step from apex to apex/failure. 
returns a sol object from integrate.solve_ivp, with all phases @@ -63,92 +227,100 @@ def step(x0, p, prev_sol=None): # @jit(nopython=True) def flight_dynamics(t, x): - # code in flight dynamics, xdot_ = f() - return np.array([x[2], x[3], 0, -GRAVITY, x[2], x[3], 0]) + return _compute_flight_dynamics(x, GRAVITY) # @jit(nopython=True) def stance_dynamics(t, x): # stance dynamics - alpha = np.arctan2(x[1] - x[5], x[0] - x[4]) - np.pi / 2.0 - spring_length = np.hypot(x[0] - x[4], x[1] - x[5]) - LEG_LENGTH_OFFSET - leg_force = STIFFNESS / MASS * (RESTING_LENGTH - spring_length) - xdotdot = -leg_force * np.sin(alpha) - ydotdot = leg_force * np.cos(alpha) - GRAVITY - return np.array([x[2], x[3], xdotdot, ydotdot, 0, 0, 0]) + return _compute_stance_dynamics( + x, + GRAVITY, + MASS, + STIFFNESS, + RESTING_LENGTH, + LEG_LENGTH_OFFSET, + ) # @jit(nopython=True) def fall_event(t, x): """ Event function to detect the body hitting the floor (failure) """ - return x[1] + return float(_evaluate_fall_event(x)) fall_event.terminal = True fall_event.direction = -1 # @jit(nopython=True) def touchdown_event(t, x): - """ - Event function for foot touchdown (transition to stance) - """ - # x[1]- np.cos(p['angle_of_attack'])*RESTING_LENGTH - # (which is = x[5]) - return x[5] - x[-1] + return float(_evaluate_touchdown_event(x)) touchdown_event.terminal = True # no longer actually necessary... 
touchdown_event.direction = -1 # @jit(nopython=True) def liftoff_event(t, x): - """ - Event function to reach maximum spring extension (transition to flight) - """ - spring_length = ( - np.hypot(x[0] - x[4], x[1] - x[5]) - p["actuator_resting_length"] + return float( + _evaluate_liftoff_event( + x, + RESTING_LENGTH, + p["actuator_resting_length"], + ) ) - return spring_length - RESTING_LENGTH - # ((x[0]-x[4])**2 + (x[1]-x[5])**2) - RESTING_LENGTH**2 liftoff_event.terminal = True liftoff_event.direction = 1 # @jit(nopython=True) def apex_event(t, x): - """ - Event function to reach apex - """ - return x[3] + return float(_evaluate_apex_event(x)) apex_event.terminal = True # @jit(nopython=True) def reversal_event(t, x): - """ - Event function for direction reversal - """ - return x[2] + 1e-5 # for numerics, allow for "straight up" + return float(_evaluate_reversal_event(x)) # allow for "straight up" reversal_event.terminal = True reversal_event.direction = -1 # * Start of step code * # - # TODO: properly update sol object with all info, not just the trajectories - if prev_sol is not None: t0 = prev_sol.t[-1] else: t0 = 0 # starting time + selected_flight_mode = flight_mode or p.get("flight_solver", "analytic") + + def run_flight_segment(state, event_sequence, start_time): + if selected_flight_mode == "numeric": + event_lookup = { + "fall": fall_event, + "touchdown": touchdown_event, + "apex": apex_event, + } + events = [event_lookup[name] for name in event_sequence] + return _integrate_flight_numerically( + flight_dynamics, + state, + start_time, + MAX_TIME, + events, + max_step=0.01, + ) + return _integrate_flight_analytically( + state, + p, + t_start=start_time, + max_time=MAX_TIME, + event_types=event_sequence, + max_step=0.01, + ) + # * FLIGHT: simulate till touchdown - events = [fall_event, touchdown_event] - sol = integrate.solve_ivp( - flight_dynamics, t_span=[t0, t0 + MAX_TIME], y0=x0, events=events, max_step=0.01 - ) + sol = run_flight_segment(x0, ("fall", 
"touchdown"), t0) - # TODO Put each part of the step into a list, so you can concat them - # TODO programmatically, and reduce code length. - # if you fell, stop now if sol.t_events[0].size != 0: # if empty if prev_sol is not None: sol.t = np.concatenate((prev_sol.t, sol.t)) @@ -180,16 +352,8 @@ def reversal_event(t, x): return sol # * FLIGHT: simulate till apex - events = [fall_event, apex_event] - x0 = reset_leg(sol2.y[:, -1], p) - sol3 = integrate.solve_ivp( - flight_dynamics, - t_span=[sol2.t[-1], sol2.t[-1] + MAX_TIME], - y0=x0, - events=events, - max_step=0.01, - ) + sol3 = run_flight_segment(x0, ("fall", "apex"), sol2.t[-1]) # concatenate all solutions sol.t = np.concatenate((sol.t, sol2.t, sol3.t)) @@ -227,10 +391,11 @@ def check_failure(x, fail_idx=(0, 1, 2)): def reset_leg(x, p): + x_new = np.array(x, copy=True) leg_length = p["resting_length"] + p["actuator_resting_length"] - x[4] = x[0] + np.sin(p["angle_of_attack"]) * leg_length - x[5] = x[1] - np.cos(p["angle_of_attack"]) * leg_length - return x + x_new[4] = x_new[0] + np.sin(p["angle_of_attack"]) * leg_length + x_new[5] = x_new[1] - np.cos(p["angle_of_attack"]) * leg_length + return x_new def compute_spring_length(x, p): @@ -297,18 +462,18 @@ def find_limit_cycle(x, p, options): # Use the bisection method to get a good initial guess for key # Initial solution - reset_leg(x, p) + x = reset_leg(x, p) # * check for feasibility # Somewhat hacky, very specific to AoA - if not feasible(x, p): + if not is_feasible(x, p): # if we're searching for AOAs, start from a feasible one if options["parameter_name"] == "angle_of_attack": # starting infeasible # assuming we're looking for an aoa for aoa in np.linspace(p["angle_of_attack"], np.pi / 2, 9): p["angle_of_attack"] = aoa - reset_leg(x, p) - if feasible(x, p): + x = reset_leg(x, p) + if is_feasible(x, p): break else: return aoa, limit_cycle_found @@ -459,12 +624,11 @@ def s2x(x, p, s): # check that we are at apex assert np.isclose(x[3], 0), "state x: " + 
str(x) + " and e: " + str(s) - x_new = p["x0"] + x_new = np.array(p["x0"], copy=True) x_new[1] = p["total_energy"] * s / p["mass"] / p["gravity"] x_new[2] = np.sqrt(p["total_energy"] * (1 - s) / p["mass"] * 2) x_new[3] = 0.0 # shouldn't be necessary, but avoids errors accumulating - x = reset_leg(x, p) - return x_new + return reset_leg(x_new, p) def sa2xp(state_action, p): @@ -472,6 +636,9 @@ def sa2xp(state_action, p): Specifically map state_actions to x and p """ p_new = p.copy() + if "x0" in p_new: + p_new["x0"] = np.array(p_new["x0"], copy=True) p_new["angle_of_attack"] = state_action[1] - x = s2x(p_new["x0"], p_new, state_action[0]).copy() + x_reference = p_new.get("x0", None) + x = s2x(x_reference, p_new, state_action[0]) return x, p_new diff --git a/pyproject.toml b/pyproject.toml index 8ba0216..5c83844 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "vibly" version = "0.2" description = "vibly" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" authors = [ { name = "Steve Heim", email = "heim.steve@gmail.com" } ] @@ -17,6 +17,7 @@ dependencies = [ "joblib>=1.4.2", "matplotlib>=3.0.3", "numpy>=1.18.4", + "numba>=0.60.0", "ruff>=0.9.7", "scipy>=1.5.1", "ttictoc>=0.5.6" @@ -79,5 +80,13 @@ skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. line-ending = "auto" -[tool.uv.workspace] -members = ["demos"] +[dependency-groups] +dev = [ + "pytest>=8.3.5", + "snakeviz>=2.2.1", +] + +[tool.pytest.ini_options] +markers = [ + "slow: mark tests that perform full viability recomputation.", +] diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..65c9d1b --- /dev/null +++ b/tests/README.md @@ -0,0 +1,23 @@ +# Test Suite Overview + +This project keeps tests lean and split by scope so you can run the right level of coverage for the change you’re making. + +## Layout + +- `tests/unit/`: fast, in-process checks for individual helpers. 
They don’t touch the heavy dynamical models. Run these constantly while editing core utilities. +- `tests/integration/`: higher-level checks that exercise data pipelines and end-to-end behaviours. We split these into: + - `test_reference_outputs.py`: quick sanity checks against frozen `.npz` fixtures. + - `test_viability_regression.py`: marked with `@pytest.mark.slow`; recomputes viability data and control policies to catch regressions. These tests rely on the pre-generated fixtures under `tests/integration/data/`. + +## Running + +- Fast loop: `uv run pytest -k "not slow"` (or `pytest` if you’re outside uv). This runs unit tests and quick integration checks. +- Full regression: `uv run pytest -m slow`. Use this before large refactors or when touching models/control logic; they take minutes. +- Targeted files: append paths/keywords, e.g. `uv run pytest tests/unit/test_viability_utils.py`. + +## Philosophy + +- Prefer deterministic fixtures (`.npz`) for expensive computations. Regenerate them only when behaviour intentionally changes, and note the rationale in the PR/commit. +- Avoid duplicating logic inside tests; import the public functions and assert on shapes, types, or exact matches. +- Keep slow tests clearly marked and grouped so contributors (and future automation) can choose between fast feedback and deeper validation. +- When adding new tests, decide early whether they belong in the fast or slow bucket, document any new fixtures, and update this README if the workflow changes. 
diff --git a/tests/integration/data/closed_satellite11.npz b/tests/integration/data/closed_satellite11.npz new file mode 100644 index 0000000..627109f Binary files /dev/null and b/tests/integration/data/closed_satellite11.npz differ diff --git a/tests/integration/data/hover_map.npz b/tests/integration/data/hover_map.npz new file mode 100644 index 0000000..92e0a8b Binary files /dev/null and b/tests/integration/data/hover_map.npz differ diff --git a/tests/integration/data/satellite_q_value.npz b/tests/integration/data/satellite_q_value.npz new file mode 100644 index 0000000..3078762 Binary files /dev/null and b/tests/integration/data/satellite_q_value.npz differ diff --git a/tests/integration/data/slip_demo_sol.npz b/tests/integration/data/slip_demo_sol.npz new file mode 100644 index 0000000..34699c4 Binary files /dev/null and b/tests/integration/data/slip_demo_sol.npz differ diff --git a/tests/integration/data/slip_map.npz b/tests/integration/data/slip_map.npz new file mode 100644 index 0000000..1e78496 Binary files /dev/null and b/tests/integration/data/slip_map.npz differ diff --git a/tests/integration/test_reference_outputs.py b/tests/integration/test_reference_outputs.py new file mode 100644 index 0000000..d7f3c60 --- /dev/null +++ b/tests/integration/test_reference_outputs.py @@ -0,0 +1,71 @@ +from pathlib import Path + +import numpy as np + + +FIXTURE_DIR = Path(__file__).resolve().parent / "data" + + +def test_slip_demo_reference_solution(): + path = FIXTURE_DIR / "slip_demo_sol.npz" + assert path.exists(), "Expected slip demo reference output to be present." + + data = np.load(path) + t = data["t"] + y = data["y"] + + diffs = np.diff(t) + assert np.all(diffs >= 0), "Solution time steps should be non-decreasing." + assert np.any(diffs > 0), "Solution should contain at least one positive timestep." 
+ assert y.shape == (7, t.size) + + +def test_hover_map_reference_contents(): + path = FIXTURE_DIR / "hover_map.npz" + assert path.exists(), "Expected hover_map.npz to be present." + + data = np.load(path) + q_map = data["Q_map"] + q_fail = data["Q_F"] + q_viable = data["Q_V"] + q_measure = data["Q_M"] + s_measure = data["S_M"] + s_grid = data["s_grid"] + a_grid = data["a_grid"] + + assert q_map.shape == (201, 161) + assert q_map.dtype == np.int64 + assert q_fail.shape == q_map.shape + assert q_viable.dtype == bool + assert q_measure.shape == q_map.shape + assert s_measure.shape == (201,) + assert np.isclose(a_grid.max(), 0.8) + assert np.isclose(s_grid.min(), 0.0) + assert np.any(q_fail) + assert np.any(~q_fail) + + +def test_slip_map_reference_contents(): + path = FIXTURE_DIR / "slip_map.npz" + assert path.exists(), "Expected slip_map.npz to be present." + + data = np.load(path) + q_map = data["Q_map"] + q_fail = data["Q_F"] + q_viable = data["Q_V"] + q_measure = data["Q_M"] + s_measure = data["S_M"] + s_grid = data["s_grid"] + a_grid = data["a_grid"] + + assert q_map.shape == (180, 161) + assert q_map.dtype == np.int64 + assert q_fail.shape == q_map.shape + assert q_viable.dtype == bool + assert q_measure.shape == q_map.shape + assert s_measure.shape == (180,) + assert np.isclose(s_grid[0], 0.1) + assert np.isclose(a_grid[0], -10 / 180 * np.pi) + assert np.all(q_measure[q_viable] >= 0.0) + assert np.any(q_measure[q_viable] > 0.0) + diff --git a/tests/integration/test_viability_regression.py b/tests/integration/test_viability_regression.py new file mode 100644 index 0000000..01fb396 --- /dev/null +++ b/tests/integration/test_viability_regression.py @@ -0,0 +1,217 @@ +from pathlib import Path +import warnings +import multiprocessing as mp + +import numpy as np +import pytest + +import models.hovership as hovership +import models.slip as slip +from models.hovership import p_map as hovership_p_map +import models.spaceship4 as spaceship4 + +import viability as vibly 
+import control + + +FIXTURE_DIR = Path(__file__).resolve().parent / "data" + +if mp.get_start_method(allow_none=True) != "fork": + try: + mp.set_start_method("fork") + except RuntimeError: + pass + + +@pytest.mark.slow +def test_hovership_viability_matches_reference(): + reference = np.load(FIXTURE_DIR / "hover_map.npz") + + p = { + "n_states": 1, + "base_gravity": 0.1, + "gravity": 1, + "thrust": 0, + "max_thrust": 0.8, + "ceiling": 2, + "control_frequency": 1, + } + x0 = np.array([0.5]) + + s_grid = (np.linspace(-0.0, p["ceiling"], 201),) + a_grid = (np.linspace(0.0, p["max_thrust"], 161),) + grids = {"states": s_grid, "actions": a_grid} + + p_map = hovership_p_map + p_map.p = p + p_map.x = x0 + p_map.sa2xp = hovership.sa2xp + p_map.xp2s = hovership.xp2s + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="Conversion of an array with ndim > 0 to a scalar is deprecated", + ) + result = vibly.compute_Q_map(grids, p_map, check_grid=True, parallel=False) + q_map = result.q_map + q_fail = result.q_fail + q_on_grid = result.q_on_grid + q_v, s_v = vibly.compute_QV(q_map, grids, Q_V=~q_fail, Q_on_grid=q_on_grid) + s_m = vibly.project_Q2S(q_v, grids, proj_opt=np.mean) + q_m = vibly.map_S2Q(q_map, s_m, s_grid, Q_V=q_v, Q_on_grid=q_on_grid) + + assert np.array_equal(q_map, reference["Q_map"]) + assert np.array_equal(q_fail, reference["Q_F"]) + assert np.array_equal(q_v, reference["Q_V"]) + assert np.allclose(s_m, reference["S_M"]) + assert np.allclose(q_m, reference["Q_M"]) + + +@pytest.mark.slow +def test_slip_viability_matches_reference(): + reference = np.load(FIXTURE_DIR / "slip_map.npz") + + p = { + "mass": 80.0, + "stiffness": 8200.0, + "resting_length": 1.0, + "gravity": 9.81, + "angle_of_attack": 1 / 5 * np.pi, + "actuator_resting_length": 0, + } + x0 = np.array([0, 0.85, 5.5, 0, 0, 0, 0], dtype=float) + x0 = slip.reset_leg(x0, p) + p["x0"] = x0 + p["total_energy"] = slip.compute_total_energy(x0, p) + 
+ s_grid = np.linspace(0.1, 1, 181) + s_grid = (s_grid[:-1],) + a_grid = (np.linspace(-10 / 180 * np.pi, 70 / 180 * np.pi, 161),) + grids = {"states": s_grid, "actions": a_grid} + + p_map = slip.p_map + p_map.p = p + p_map.x = x0 + p_map.sa2xp = slip.sa2xp + p_map.xp2s = slip.xp2s + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="Conversion of an array with ndim > 0 to a scalar is deprecated", + ) + result = vibly.compute_Q_map(grids, p_map) + q_map = result.q_map + q_fail = result.q_fail + q_v, s_v = vibly.compute_QV(q_map, grids, Q_V=~q_fail) + s_m = vibly.project_Q2S(q_v, grids, proj_opt=np.mean) + q_m = vibly.map_S2Q(q_map, s_m, grids["states"], Q_V=q_v) + + assert np.array_equal(q_map, reference["Q_map"]) + assert np.array_equal(q_fail, reference["Q_F"]) + assert np.array_equal(q_v, reference["Q_V"]) + assert np.allclose(s_m, reference["S_M"]) + assert np.allclose(q_m, reference["Q_M"]) + + +@pytest.mark.slow +def test_satellite_parcompute_matches_reference(): + fixture = np.load(FIXTURE_DIR / "closed_satellite11.npz") + s_grid = (fixture["s_grid_0"], fixture["s_grid_1"]) + a_grid = (fixture["a_grid"],) + grids = {"states": s_grid, "actions": a_grid} + + p = { + "n_states": 2, + "geocentric_constant": 10.0, + "geocentric_radius": 10.0, + "angular_speed": 0.1, + "mass": 1.0, + "control_frequency": 1, + "thrust": 1.0, + "radius": 1.0, + "radio_range": 15.0, + } + x0 = fixture["x0"] + + import models.satellite as satellite + from models.satellite import p_map as satellite_p_map + + p_map = satellite_p_map + p_map.p = p + p_map.x = x0 + p_map.sa2xp = satellite.sa2xp + p_map.xp2s = satellite.xp2s + + result = vibly.compute_Q_map(grids, p_map, keep_coords=True, bin_mode="nearest") + + q_map = result.q_map + q_fail = result.q_fail + + assert np.array_equal(q_map, fixture["Q_map"]) + assert np.array_equal(q_fail, fixture["Q_F"]) + + +@pytest.mark.slow +def 
test_satellite_value_iteration_matches_reference(): + fixture = np.load(FIXTURE_DIR / "closed_satellite11.npz") + q_map = fixture["Q_map"] + s_grid = (fixture["s_grid_0"], fixture["s_grid_1"]) + a_grid = (fixture["a_grid"],) + grids = {"states": s_grid, "actions": a_grid} + q_on_grid = np.ones(q_map.shape, dtype=bool) + + reference = np.load(FIXTURE_DIR / "satellite_q_value.npz") + expected_q_value = reference["Q_value"] + expected_r_value = reference["R_value"] + failure_penalty = float(reference["failure_penalty"]) + + radius = float(fixture["p_radius"]) + radio_range = float(fixture["p_radio_range"]) + + def parsimonious_reward(s, a): + if np.allclose(s, [10.0, 0.0]): + return 1.0 + return 0.0 + + def make_penalty(penalty_scale): + def penalty(s, a): + if s[0] >= radio_range or s[0] <= radius: + return -penalty_scale + return 0.0 + + return penalty + + reward_functions = (parsimonious_reward, make_penalty(failure_penalty)) + + # Warm-start with zero-penalty run as in VIS_basic.py + zero_penalty_reward = (parsimonious_reward, make_penalty(0.0)) + q_value, _ = control.Q_value_iteration( + q_map, + grids, + zero_penalty_reward, + 0.6, + Q_on_grid=q_on_grid, + stopping_threshold=1e-6, + max_iter=1000, + output_R=True, + Q_values=None, + ) + + q_value, r_value = control.Q_value_iteration( + q_map, + grids, + reward_functions, + 0.6, + Q_on_grid=q_on_grid, + stopping_threshold=1e-6, + max_iter=1000, + output_R=True, + Q_values=q_value, + ) + + assert np.allclose(q_value, expected_q_value) + assert np.allclose(r_value, expected_r_value) diff --git a/tests/unit/test_slip_flight.py b/tests/unit/test_slip_flight.py new file mode 100644 index 0000000..fb97ca9 --- /dev/null +++ b/tests/unit/test_slip_flight.py @@ -0,0 +1,37 @@ +import numpy as np + +from models import slip + + +def _build_default_params(): + params = { + "mass": 80.0, + "stiffness": 8200.0, + "resting_length": 1.0, + "gravity": 9.81, + "angle_of_attack": 1 / 5 * np.pi, + "actuator_resting_length": 0.0, + } 
+ x0 = np.array([0, 0.85, 5.5, 0, 0, 0, 0], dtype=float) + x0 = slip.reset_leg(x0, params) + params["x0"] = x0 + params["total_energy"] = slip.compute_total_energy(x0, params) + return x0, params + + +def test_step_analytic_matches_numeric_flight(): + x0, params = _build_default_params() + + analytic = slip.step(x0, params) + numeric = slip.step(x0, params, flight_mode="numeric") + + np.testing.assert_allclose(analytic.y[:, -1], numeric.y[:, -1], atol=1e-9) + + assert len(analytic.t_events) == len(numeric.t_events) + for idx, (t_a, t_b) in enumerate(zip(analytic.t_events, numeric.t_events)): + np.testing.assert_allclose( + t_a, + t_b, + atol=1e-9, + err_msg=f"event index {idx} diverged", + ) diff --git a/tests/unit/test_viability_utils.py b/tests/unit/test_viability_utils.py new file mode 100644 index 0000000..774df29 --- /dev/null +++ b/tests/unit/test_viability_utils.py @@ -0,0 +1,190 @@ +import numpy as np + +from viability import viability as vibly + + +def test_digitize_s_handles_bins_and_coordinates(): + s_grid = (np.array([0.0, 0.5, 1.0]), np.array([-1.0, 0.0, 1.0])) + + s = np.array([0.5, 0.0]) + bin_indices = vibly.digitize_s(s, s_grid) + assert tuple(bin_indices) == (2, 2) + + ravelled = vibly.digitize_s(s, s_grid, shape=(3, 3)) + assert ravelled == np.ravel_multi_index((2, 2), (3, 3)) + + coord_indices = vibly.digitize_s(s, s_grid, to_bin=False) + assert tuple(coord_indices) == (1, 1) + + +def test_get_grid_indices_returns_enclosing_vertices(): + s_grid = (np.arange(3), np.arange(4)) + + neighbors = vibly.get_grid_indices((1, 2), s_grid) + assert set(neighbors) == {(0, 1), (0, 2), (1, 1), (1, 2)} + + # Edge bins should only report in-bounds vertices + edge_neighbors = vibly.get_grid_indices((0, 0), s_grid) + assert edge_neighbors == [(0, 0)] + + +def test_project_Q2S_respects_projection_operator(): + grids = { + "states": (np.array([0, 1]), np.array([0, 1])), + "actions": (np.array([0]),), + } + q_values = np.array([[[True], [False]], [[False], [True]]]) 
+ + default_projection = vibly.project_Q2S(q_values, grids) + assert default_projection.dtype == bool + assert np.array_equal(default_projection, np.array([[True, False], [False, True]])) + + summed = vibly.project_Q2S(q_values.astype(int), grids, proj_opt=np.sum) + assert np.array_equal(summed, np.array([[1, 0], [0, 1]])) + + +class DummyMap: + def __init__(self): + self.p = {} + + def sa2xp(self, state_action, params): + state, action = state_action + return np.array([state, action]), params + + def xp2s(self, x_next, params): + return np.array([x_next[0]]) + + def __call__(self, x, params): + state, action = x + next_state = state + action + failed = (next_state < 0.0) or (next_state > 2.0) + return np.array([next_state, action]), failed + + +def test_compute_Q_map_and_compute_QV_simple_system(): + grids = { + "states": (np.array([0.0, 1.0, 2.0]),), + "actions": (np.array([-1.0, 1.0]),), + } + p_map = DummyMap() + + result = vibly.compute_Q_map(grids, p_map, parallel=False) + q_map = result.q_map + q_fail = result.q_fail + assert q_map.shape == (3, 2) + assert q_fail.shape == (3, 2) + + # state 0 with action -1 should fail, state 2 with action +1 should fail + assert q_fail[0, 0] + assert q_fail[2, 1] + assert not q_fail[1, 0] + + q_v, s_v = vibly.compute_QV(q_map, grids, Q_V=~q_fail) + # All states remain viable because each has at least one safe action + assert np.array_equal(s_v.astype(int), np.array([1, 1, 1])) + + # Map a simple state-space measure back into Q + state_measure = np.array([1.0, 0.5, 0.0]) + q_m = vibly.map_S2Q(q_map, state_measure, grids["states"], Q_V=q_v) + assert q_m.shape == q_map.shape + expected_q_m = np.array([[0.0, 0.25], [0.75, 0.0], [0.25, 0.0]]) + assert np.allclose(q_m, expected_q_m) + + +def test_compute_Q_map_with_check_grid_reports_on_grid_hits(): + grids = { + "states": (np.array([0.0, 1.0, 2.0]),), + "actions": (np.array([-1.0, 1.0]),), + } + p_map = DummyMap() + + result = vibly.compute_Q_map( + grids, p_map, 
check_grid=True, keep_coords=True, parallel=False + ) + q_map = result.q_map + q_fail = result.q_fail + q_on_grid = result.q_on_grid + q_reached = result.q_reached + + assert q_map.shape == (3, 2) + assert q_on_grid.shape == q_map.shape + assert q_on_grid.dtype == bool + assert np.any(q_on_grid) + + total_points = np.prod([grid.size for grid in grids["states"]]) * np.prod( + [grid.size for grid in grids["actions"]] + ) + assert q_reached.shape == (len(grids["states"]), total_points) + + +def test_compute_QV_same_with_and_without_on_grid_flag(): + grids = { + "states": (np.array([0.0, 1.0, 2.0]),), + "actions": (np.array([-1.0, 1.0]),), + } + p_map = DummyMap() + + result = vibly.compute_Q_map( + grids, p_map, check_grid=True, keep_coords=True, parallel=False + ) + q_map = result.q_map + q_fail = result.q_fail + q_on_grid = result.q_on_grid + + q_v_default, s_v_default = vibly.compute_QV(q_map, grids, Q_V=~q_fail) + q_v_on_grid, s_v_on_grid = vibly.compute_QV( + q_map, grids, Q_V=~q_fail, Q_on_grid=q_on_grid + ) + + assert np.array_equal(q_v_default, q_v_on_grid) + assert np.array_equal(s_v_default, s_v_on_grid) + + +def test_is_outside_handles_on_grid_points(): + s_grid = (np.array([0.0, 1.0, 2.0]),) + S_V = np.array([True, False, True]) + + idx_viable = np.ravel_multi_index((0,), (3,)) + assert not vibly.is_outside(idx_viable, s_grid, S_V, on_grid=True) + + idx_outside = np.ravel_multi_index((1,), (3,)) + assert vibly.is_outside(idx_outside, s_grid, S_V, on_grid=True) + + # also exercise the "not already binned" path (currently treats boundary points as outside) + assert vibly.is_outside([0.0], s_grid, S_V, already_binned=False) + assert vibly.is_outside([1.5], s_grid, S_V, already_binned=False) + + +def test_map_S2Q_uses_on_grid_lookup(): + grids = { + "states": (np.array([0.0, 1.0, 2.0]),), + "actions": (np.array([-1.0, 1.0]),), + } + p_map = DummyMap() + + result = vibly.compute_Q_map( + grids, p_map, check_grid=True, keep_coords=True, parallel=False + ) + 
q_map = result.q_map + q_fail = result.q_fail + q_on_grid = result.q_on_grid + q_v = ~q_fail + + state_measure = np.array([1.0, 0.5, 0.0]) + q_m = vibly.map_S2Q( + q_map, + state_measure, + grids["states"], + Q_V=q_v, + Q_on_grid=q_on_grid, + ) + + assert q_m.shape == q_map.shape + + mask = q_v & q_on_grid + assert np.any(mask) + for idx in zip(*np.where(mask)): + s_idx = np.unravel_index(q_map[idx], state_measure.shape) + neighbors = vibly.get_grid_indices(s_idx, grids["states"]) + expected = np.mean([state_measure[n] for n in neighbors]) + assert np.isclose(q_m[idx], expected) diff --git a/viability/__init__.py b/viability/__init__.py index 4da2981..ae94f94 100644 --- a/viability/__init__.py +++ b/viability/__init__.py @@ -7,6 +7,4 @@ from .viability import get_feasibility_mask from .viability import get_grid_indices from .viability import is_outside -from .viability import parcompute_Q_map -from .viability import parcompute_Q_mapC from .viability import digitize_s diff --git a/viability/viability.py b/viability/viability.py index 06fd4a8..1d9bb98 100644 --- a/viability/viability.py +++ b/viability/viability.py @@ -1,128 +1,22 @@ import itertools as it -import numpy as np import multiprocessing as mp +from dataclasses import dataclass +from typing import Iterable, Iterator, Optional, Sequence, Tuple + +import numpy as np """ Tools for computing the viable set (in state-action space) and viability kernel (in -state space) of a dynamical system. Note, all methods appended with `_2D` are specific -to 2D systems and deprecated. They are left here since they are a lot easier to read, -in case you're interested in understanding what the code is doing. - -For all practical use, refer to the N-D versions. +state space) of a dynamical system in ND. """ -def compute_Q_2D(s_grid, a_grid, p_map): - """Compute the transition map of a system with 1D state and 1D action - NOTES - - s_grid and a_grid have to be iterable lists of lists - e.g. 
if they have only 1 dimension, they should be `s_grid = ([1, 2], )` - - use p_map to carry parameters - """ - - # create iterators s_grid, a_grid - # TODO: pass in iterators/generators instead - # compute each combination, store result in a huge matrix - - # initialize 1D, reshape later - Q_map = np.zeros((s_grid.size * a_grid.size, 1)) - Q_F = np.zeros((s_grid.size * a_grid.size, 1)) - - # QTransition = Q_map - n = len(s_grid) * len(a_grid) - for idx, state_action in enumerate(it.product(s_grid, a_grid)): - if idx % (n / 10) == 0: - print(".", end=" ") - x, p = p_map.sa2xp(state_action, p_map.p) - x_next, failed = p_map(x, p) - if not failed: - s_next = p_map.xp2s(x_next, p) - # note: Q_map is implicitly already excluding transitions that - # move straight to a failure. While this is not equivalent to the - # algorithm in the paper, for our systems it is a bit faster - Q_map[idx] = s_next - else: - Q_F[idx] = 1 - - return ( - Q_map.reshape((s_grid.size, a_grid.size)), # only 2D - Q_F.reshape((s_grid.size, a_grid.size)), - ) - - -def project_Q2S_2D(Q): - S = np.zeros((Q.shape[0], 1)) - for sdx, val in enumerate(S): - if sum(Q[sdx, :]) > 0: - S[sdx] = 1 - return S - - -def is_outside_2D(s, S_V, s_grid): - """ - given a level set S, check if s is inside S or not - """ - if sum(S_V) <= 1: - return True - - s_min, s_max = s_grid[S_V > 0][[0, -1]] - if s > s_max or s < s_min: - return True - else: - return False - - -def compute_QV_2D(Q_map, grids, Q_V=None): - """Starting from the transition map and set of non-failing state-action - pairs, compute the viable sets. The input Q_V is referred to as Q_N in the - paper when passing it in, but since it is immediately copied to Q_V, we - directly use this naming. 
- """ - - # Take Q_map as the non-failing set if Q_N is omitted - if Q_V is None: - Q_V = np.copy(Q_map) - Q_V[Q_V > 0] = 1 - - S_old = np.zeros((Q_V.shape[0], 1)) - S_V = project_Q2S_2D(Q_V) - while np.array_equal(S_V, S_old): - for qdx, is_viable in enumerate(np.nditer(Q_V)): # compare w/ np.enum - if is_viable: # only check previously viable (s, a) - if is_outside_2D(Q_map[qdx], S_V, grids["states"]): - Q_V[qdx] = 0 # remove - S_old = S_V - S_V = project_Q2S_2D(Q_V) - - return Q_V, S_V - - -# * Reimplement everything as N-D - - -def get_state_from_ravel(bin_idx, s_grid): - """ - Get state from bin id. Ideally, interpolate - For now, just returning a grid point - """ - bin_idx = np.atleast_1d(bin_idx) - grid_idx = np.zeros(len(s_grid), dtype="int") - s = np.zeros(len(s_grid)) - for dim, grid in enumerate(s_grid): - if bin_idx[dim] >= grid.size: - grid_idx[dim] = grid.size - 1 # upper-est entry - s[dim] = grid[grid_idx[dim]] - else: - grid_idx[dim] = bin_idx[dim] # just put the right-closest grid - s[dim] = grid[grid_idx[dim]] - return s - - -def bin2grid(bin_idx, grids): - """ - To replace `get_state_from_ravel` - receiving a tuple of grids, return the grid- - """ +@dataclass +class TransitionResult: + q_map: np.ndarray + q_fail: np.ndarray + q_on_grid: Optional[np.ndarray] = None + q_reached: Optional[np.ndarray] = None def digitize_s(s, s_grid, shape=None, to_bin=True): @@ -153,86 +47,6 @@ def digitize_s(s, s_grid, shape=None, to_bin=True): return np.ravel_multi_index(s_idx, shape) -def compute_Q_map(grids, p_map, verbose=0, check_grid=False, keep_coords=False): - """Compute the transition map of a system - NOTES - - s_grid and a_grid have to be iterable lists of lists - e.g. 
if they have only 1 dimension, they should be `s_grid = ([1, 2], )` - - use p_map to carry parameters - - keep_coords: toggle to true to also output an array of actual states - """ - # TODO get rid of check_grid, solve the problem permanently - - # initialize 1D, reshape later - # shape of state-space grid - s_grid_shape = list(map(np.size, grids["states"])) - s_bin_shape = tuple(dim + 1 for dim in s_grid_shape) - a_grid_shape = list(map(np.size, grids["actions"])) - total_gridpoints = np.prod(s_grid_shape) * np.prod(a_grid_shape) - - if verbose > 0: - print("computing a total of " + str(total_gridpoints) + " points.") - - Q_map = np.zeros((total_gridpoints, 1), dtype=int) - Q_F = np.zeros((total_gridpoints, 1), dtype=bool) - if keep_coords: - Q_reached = np.zeros((len(grids["states"]), total_gridpoints)) - - if check_grid: - Q_on_grid = np.copy(Q_F) # HACK: keep track of wether you are in a bin - - for idx, state_action in enumerate( - np.array(list(it.product(*grids["states"], *grids["actions"]))) - ): - if verbose > 1: - # NOTE: requires running python unbuffered (python -u) - if idx % (total_gridpoints / 10) == 0: - print(".", end=" ") - - x, p = p_map.sa2xp(state_action, p_map.p) - - x_next, failed = p_map(x, p) - - s_next = p_map.xp2s(x_next, p) - if keep_coords: - Q_reached[:, idx] = s_next - if not failed: - # note: Q_map is implicitly already excluding transitions that - # move straight to a failure. While this is not equivalent to the - # algorithm in the paper, for our systems it is a bit faster - # bin_idx = np.digitize(state_val, s_grid[state_dx]) - # sbin = np.digitize(s_next, s) - if check_grid: - for sdx, sval in enumerate(np.atleast_1d(s_next)): - if ~np.isin(sval, grids["states"][sdx]): - Q_map[idx] = digitize_s(s_next, grids["states"], s_bin_shape) - break - else: - Q_on_grid[idx] = True - Q_map[idx] = digitize_s( - s_next, grids["states"], s_grid_shape, to_bin=False - ) - else: - # TODO: should actually do this even when failing. 
- Q_map[idx] = digitize_s(s_next, grids["states"], s_bin_shape) - - # check if s happens to be right on the grid-point - else: - Q_F[idx] = 1 - - Q_map = Q_map.reshape(s_grid_shape + a_grid_shape) - Q_F = Q_F.reshape(s_grid_shape + a_grid_shape) - - deliver = [Q_map, Q_F] - - if check_grid: - deliver.append(Q_on_grid.reshape(s_grid_shape + a_grid_shape)) - if keep_coords: - deliver.append(Q_reached) - - return deliver - - def project_Q2S(Q, grids, proj_opt=None): if proj_opt is None: proj_opt = np.any @@ -305,28 +119,6 @@ def is_outside(s, s_grid, S_V, already_binned=True, on_grid=False): return True return False - # for dim_idx, grid in enumerate(s_grid): - # # if outside the left-most or right-most side of grid, mark as outside - # # * NOTE: this can result in disastrous underestimations if the grid is - # # * not larger than the viable set! - # # TODO: this can lead to understimation if s is right on the gridline - # # because it will still check its neighbors. need to first check. - # if bin_idx[dim_idx] == 0: - # return True - # elif bin_idx[dim_idx] >= grid.size: - # return True - # # Need to redo the loop. In the first loop, we check if any of the points - # # has exited the grid. 
This needs to be done first to ensure we don't try - # # to index outside the grid - # for dim_idx, grid in enumerate(s_grid): - # # check if enclosing grid points are viable or not - # index_vec = np.zeros(len(s_grid), dtype=int) - # index_vec[dim_idx] = 1 - # if (not S_V[tuple(bin_idx)] or - # not S_V[tuple(bin_idx - index_vec)]): - # return True - - # return False def get_grid_indices(bin_idx, s_grid): @@ -404,227 +196,145 @@ def get_feasibility_mask(feasible, sa2xp, grids, x0, p0): a_shape = list(map(np.size, grids["actions"])) Q_feasible = np.zeros(np.prod(s_shape) * np.prod(a_shape), dtype=bool) - # TODO: can probably simplify this - for idx, state_action in enumerate( - np.array(list(it.product(*grids["states"], *grids["actions"]))) - ): + # state_action_iter = it.product(*grids["states"], *grids["actions"]) + for idx, state_action in enumerate(it.product(*grids["states"], *grids["actions"])): x, p = sa2xp(state_action, p0) Q_feasible[idx] = feasible(x, p) return Q_feasible.reshape(s_shape + a_shape) -# def compute_Q_cont(grids, p_map, verbose=0): -# ''' Compute the transition map of a system, and output the result _without_ -# discretizing into bins, as an array of coordinate vectors (n, m) where -# n is the dimensionality of state, and m are the number of grid-points -# NOTES -# - s_grid and a_grid have to be iterable lists of lists -# e.g. 
if they have only 1 dimension, they should be `s_grid = ([1, 2], )` -# - use p_map to carry parameters -# ''' - -# # initialize 1D, reshape later -# # shape of state-space grid -# # initialize 1D, reshape later -# # shape of state-space grid -# s_grid_shape = list(map(np.size, grids['states'])) -# a_grid_shape = list(map(np.size, grids['actions'])) -# total_gridpoints = np.prod(s_grid_shape)*np.prod(a_grid_shape) - -# if verbose > 0: -# print('computing a total of ' + str(total_gridpoints) + ' points.') - -# # Q_map = np.zeros((total_gridpoints, 1), dtype=int) -# Q_map = np.zeros((len(grids['states']), total_gridpoints)) -# Q_F = np.zeros((total_gridpoints, 1), dtype=bool) - -# for idx, state_action in enumerate(np.array(list( -# it.product(*grids['states'], *grids['actions'])))): - -# if verbose > 1: -# # NOTE: requires running python unbuffered (python -u) -# if idx % (total_gridpoints/10) == 0: -# print('.', end=' ') - -# x, p = p_map.sa2xp(state_action, p_map.p) -# x_next, failed = p_map(x, p) -# s_next = p_map.xp2s(x_next, p) -# Q_map[:, idx] = s_next -# if failed: -# Q_F[idx] = True - -# return (Q_map, Q_F) - - -def parcompute_Q_map(grids, p_map, verbose=0, check_grid=False, keep_coords=False): - """Compute the transition map of a system in parallel - - s_grid and a_grid have to be iterable lists of lists - e.g. 
if they have only 1 dimension, they should be `s_grid = ([1, 2], )` - - use p_map to carry parameters - - keep_coords: toggle to true to also output an array of actual states - """ +def _grid_shapes(grids): + s_shape = tuple(map(np.size, grids["states"])) + a_shape = tuple(map(np.size, grids["actions"])) + return s_shape, a_shape - # initialize 1D, reshape later - # shape of state-space grid - s_grid_shape = list(map(np.size, grids["states"])) - s_bin_shape = tuple(dim + 1 for dim in s_grid_shape) - a_grid_shape = list(map(np.size, grids["actions"])) - total_gridpoints = np.prod(s_grid_shape) * np.prod(a_grid_shape) - if verbose > 0: - print("computing a total of " + str(total_gridpoints) + " points.") - # initilize pool - pool = mp.Pool() - # create list of args - SA = list(it.product(*grids["states"], *grids["actions"])) - # for sa in SA: - # print(sa[1],end=' in a, and ') - p = p_map.p.copy() - args = [p_map.sa2xp(sa, p) for sa in SA] - # for ar in args: - # print(ar[1]['angle_of_attack'], end='') - # print(" and " + str(ar[0][1])) - # start pool with starmap - results = pool.starmap(p_map, args) - pool.close() - - # for r in results: - # print(r[0]) - - # do the standard stuff (put into bins etc.) 
- - Q_map = np.zeros((total_gridpoints, 1), dtype=int) - Q_F = np.zeros((total_gridpoints, 1), dtype=bool) - if keep_coords: - Q_reached = np.zeros((len(grids["states"]), total_gridpoints)) - - if check_grid: - Q_on_grid = np.copy(Q_F) # HACK: keep track of wether you are in a bin - - # TODO: no need to use it.product here, since we just use the index - for idx, sa in enumerate( - np.array(list(it.product(*grids["states"], *grids["actions"]))) - ): - x_next, failed = results[idx] - s_next = p_map.xp2s(x_next, p) - if keep_coords: - Q_reached[:, idx] = s_next - if not failed: - if check_grid: - for sdx, sval in enumerate(np.atleast_1d(s_next)): - if ~np.isin(sval, grids["states"][sdx]): - Q_map[idx] = digitize_s(s_next, grids["states"], s_bin_shape) - break - else: - Q_on_grid[idx] = True - Q_map[idx] = digitize_s( - s_next, grids["states"], s_grid_shape, to_bin=False - ) - else: - Q_map[idx] = digitize_s(s_next, grids["states"], s_bin_shape) +def _total_gridpoints(grids): + s_shape, a_shape = _grid_shapes(grids) + return int(np.prod(s_shape) * np.prod(a_shape)) - # check if s happens to be right on the grid-point - else: - Q_F[idx] = 1 - Q_map = Q_map.reshape(s_grid_shape + a_grid_shape) - Q_F = Q_F.reshape(s_grid_shape + a_grid_shape) +def _assemble_transition( + grids, + records: Iterable[Tuple[np.ndarray, bool]], + total_count: int, + *, + check_grid: bool, + keep_coords: bool, + bin_mode: str, +) -> TransitionResult: + s_grid_shape, a_grid_shape = _grid_shapes(grids) + s_bin_shape = tuple(dim + 1 for dim in s_grid_shape) - deliver = [Q_map, Q_F] + q_map_flat = np.zeros(total_count, dtype=int) + q_fail_flat = np.zeros(total_count, dtype=bool) + q_on_grid_flat = np.zeros(total_count, dtype=bool) if check_grid else None + q_reached = np.zeros((len(grids["states"]), total_count)) if keep_coords else None - if check_grid: - deliver.append(Q_on_grid.reshape(s_grid_shape + a_grid_shape)) - if keep_coords: - deliver.append(Q_reached) + if bin_mode == "nearest": - 
return deliver + def encode(point): + return digitize_s(point, grids["states"], shape=s_grid_shape, to_bin=False) + encode_on_grid = encode + else: -def parcompute_Q_mapC(grids, p_map, verbose=0, check_grid=False, keep_coords=False): - """Compute the transition map of a system in parallel - - s_grid and a_grid have to be iterable lists of lists - e.g. if they have only 1 dimension, they should be `s_grid = ([1, 2], )` - - use p_map to carry parameters - - keep_coords: toggle to true to also output an array of actual states - """ - # TODO: integrate this into standard parcompute, and clean up. + def encode(point): + return digitize_s(point, grids["states"], s_bin_shape) + + def encode_on_grid(point): + return digitize_s(point, grids["states"], s_grid_shape, to_bin=False) + + for idx, (s_next, failed) in enumerate(records): + s_vec = np.atleast_1d(s_next) + if keep_coords and q_reached is not None: + q_reached[:, idx] = s_vec + if failed: + q_fail_flat[idx] = True + continue + if check_grid and q_on_grid_flat is not None: + for dim_idx, sval in enumerate(s_vec): + if not np.isin(sval, grids["states"][dim_idx]): + q_map_flat[idx] = encode(s_vec) + break + else: + q_on_grid_flat[idx] = True + q_map_flat[idx] = encode_on_grid(s_vec) + else: + q_map_flat[idx] = encode(s_vec) + + q_map = q_map_flat.reshape(s_grid_shape + a_grid_shape) + q_fail = q_fail_flat.reshape(s_grid_shape + a_grid_shape) + q_on_grid = ( + q_on_grid_flat.reshape(s_grid_shape + a_grid_shape) + if check_grid and q_on_grid_flat is not None + else None + ) - # initialize 1D, reshape later - # shape of state-space grid - s_grid_shape = list(map(np.size, grids["states"])) - # s_bin_shape = tuple(dim + 1 for dim in s_grid_shape) - a_grid_shape = list(map(np.size, grids["actions"])) - total_gridpoints = np.prod(s_grid_shape) * np.prod(a_grid_shape) - if verbose > 0: - print("computing a total of " + str(total_gridpoints) + " points.") + return TransitionResult( + q_map=q_map, + q_fail=q_fail, + 
def compute_Q_map(
    grids,
    p_map,
    verbose=0,
    check_grid=False,
    keep_coords=False,
    parallel=True,
    bin_mode="bin",
):
    """Evaluate ``p_map`` on every grid point and assemble the transition map.

    Each combination of the state grids and action grids is converted with
    ``p_map.sa2xp``, stepped once with ``p_map`` and mapped back to a state
    with ``p_map.xp2s``.  The resulting ``(next_state, failed)`` records are
    passed to ``_assemble_transition``, whose return value is returned as is.

    Parameters
    ----------
    grids : dict
        Provides ``"states"`` and ``"actions"``; each is an iterable holding
        one grid per dimension (they are unpacked into ``itertools.product``).
    p_map : callable
        Called as ``p_map(x, p)`` and must return ``(x_next, failed)``.  It
        must also expose ``p``, ``sa2xp(state_action, p)`` returning
        ``(x, p)``, and ``xp2s(x, p)``.
    verbose : int
        ``> 0`` prints the total number of grid points; ``> 1`` also prints
        progress dots, provided there are at least 10 grid points.
    check_grid, keep_coords : bool
        Forwarded unchanged to ``_assemble_transition``.
    parallel : bool
        When true, all points are evaluated up front in a
        ``multiprocessing.Pool`` using a copy of ``p_map.p``.  When false,
        points are evaluated lazily, one by one, in this process.
    bin_mode : str
        Either ``"bin"`` or ``"nearest"``; forwarded to
        ``_assemble_transition``.

    Raises
    ------
    ValueError
        If ``bin_mode`` is not one of the supported values.
    """
    if bin_mode not in ("bin", "nearest"):
        raise ValueError(f"Unsupported bin_mode '{bin_mode}'")

    n_points = _total_gridpoints(grids)
    if verbose > 0:
        print(f"computing a total of {n_points} points.")

    # Dots are only shown at verbosity 2+ and for grids of ten or more points.
    show_dots = verbose > 1 and n_points >= 10
    dot_every = max(n_points // 10, 1) if show_dots else None

    def tick(position):
        if dot_every and position % dot_every == 0:
            print(".", end=" ")

    sa_pairs = it.product(*grids["states"], *grids["actions"])

    if parallel:
        shared_params = p_map.p.copy()
        jobs = [p_map.sa2xp(pair, shared_params) for pair in sa_pairs]
        with mp.Pool() as workers:
            outcomes = workers.starmap(p_map, jobs)

        # Post-process in the parent process, pairing each result with the
        # parameters that were actually sent to the worker.
        records = []
        for position, (outcome, job) in enumerate(zip(outcomes, jobs)):
            tick(position)
            x_next, failed = outcome
            _, job_params = job
            reached = p_map.xp2s(x_next, job_params)
            records.append((np.atleast_1d(reached), bool(failed)))
    else:

        def serial_records():
            # Kept as a generator so that points are simulated on demand
            # while the transition arrays are being filled.
            for position, pair in enumerate(sa_pairs):
                tick(position)
                x, params = p_map.sa2xp(pair, p_map.p)
                x_next, failed = p_map(x, params)
                reached = p_map.xp2s(x_next, params)
                yield np.atleast_1d(reached), bool(failed)

        records = serial_records()

    return _assemble_transition(
        grids,
        records,
        n_points,
        check_grid=check_grid,
        keep_coords=keep_coords,
        bin_mode=bin_mode,
    )
import itertools as it

import numpy as np


def compute_Q_2D(s_grid, a_grid, p_map):
    """Compute the transition map of a system with 1D state and 1D action.

    Parameters
    ----------
    s_grid, a_grid : numpy.ndarray
        1-D grids of states and actions.
    p_map : callable
        Called as ``p_map(x, p)`` and must return ``(x_next, failed)``.  It
        must also expose ``p``, ``sa2xp(state_action, p)`` and
        ``xp2s(x, p)``.

    Returns
    -------
    Q_map : numpy.ndarray, shape ``(s_grid.size, a_grid.size)``
        Next state reached from every (state, action) pair; entries of
        failed transitions are left at 0.
    Q_F : numpy.ndarray, same shape
        1 where the transition failed, 0 elsewhere.
    """
    n_states, n_actions = s_grid.size, a_grid.size
    n = n_states * n_actions

    Q_map = np.zeros((n, 1))
    Q_F = np.zeros((n, 1))

    # Integer step: a float modulus (idx % (n / 10)) only hits zero by
    # accident unless n is a multiple of ten.
    step = max(n // 10, 1)
    for idx, state_action in enumerate(it.product(s_grid, a_grid)):
        if idx % step == 0:
            print(".", end=" ")
        x, p = p_map.sa2xp(state_action, p_map.p)
        x_next, failed = p_map(x, p)
        if failed:
            Q_F[idx] = 1
        else:
            # note: Q_map is implicitly already excluding transitions that
            # move straight to a failure. While this is not equivalent to the
            # algorithm in the paper, for our systems it is a bit faster
            Q_map[idx] = p_map.xp2s(x_next, p)

    shape = (n_states, n_actions)  # only 2D
    return Q_map.reshape(shape), Q_F.reshape(shape)


def project_Q2S_2D(Q):
    """Project a 2-D state-action array onto state space.

    Returns a float column vector of shape ``(Q.shape[0], 1)`` holding 1 for
    every row of ``Q`` whose sum is positive and 0 otherwise.
    """
    return (np.sum(Q, axis=1, keepdims=True) > 0).astype(float)


def is_outside_2D(s, S_V, s_grid):
    """Given a level set ``S_V``, check whether ``s`` lies outside of it.

    Parameters
    ----------
    s : float
        State to test.
    S_V : array_like
        Indicator of the level set over the state grid; either flat or a
        column vector as returned by ``project_Q2S_2D``.
    s_grid : array_like
        The 1-D state grid, or a one-element sequence wrapping it.

    Returns
    -------
    bool
        True if ``S_V`` sums to at most 1, or if ``s`` lies outside the
        interval between the first and the last grid point marked in
        ``S_V``.
    """
    if np.sum(S_V) <= 1:
        return True

    # Flatten both operands: a (n, 1) mask cannot index a 1-D grid.
    inside = np.ravel(S_V) > 0
    members = np.ravel(s_grid)[inside]
    s_min, s_max = members[0], members[-1]
    return bool(s > s_max or s < s_min)


def compute_QV_2D(Q_map, grids, Q_V=None):
    """Compute viable sets for the 2D specialisation.

    Starting from an estimate ``Q_V``, repeatedly discard every
    (state, action) pair whose successor state in ``Q_map`` is outside the
    current projection ``S_V``, until ``S_V`` no longer changes.

    Parameters
    ----------
    Q_map : numpy.ndarray
        2-D array (states x actions) of successor states, e.g. from
        ``compute_Q_2D``.
    grids : dict
        ``grids["states"]`` is the 1-D state grid, or a one-element
        sequence wrapping it.
    Q_V : numpy.ndarray, optional
        Initial estimate; it is modified in place.  By default every entry
        of ``Q_map`` that is greater than 0 is taken as viable.

    Returns
    -------
    Q_V : numpy.ndarray
        Array of the same shape as ``Q_map`` with 1 marking viable pairs.
    S_V : numpy.ndarray, shape ``(n_states, 1)``
        Projection of ``Q_V`` onto state space.
    """
    if Q_V is None:
        Q_V = np.copy(Q_map)
        Q_V[Q_V > 0] = 1

    s_grid = np.ravel(grids["states"])

    S_old = np.zeros((Q_V.shape[0], 1))
    S_V = project_Q2S_2D(Q_V)
    # Iterate until the projection stops changing (fixed point).  The outside
    # test only depends on S_V, so an unchanged S_V means nothing more can be
    # removed.
    while not np.array_equal(S_V, S_old):
        # Only previously viable pairs need checking.  Flat indices are used
        # on both arrays so that single entries, not whole rows, are
        # read and cleared.
        for qdx in np.flatnonzero(Q_V):
            if is_outside_2D(Q_map.flat[qdx], S_V, s_grid):
                Q_V.flat[qdx] = 0
        S_old = S_V
        S_V = project_Q2S_2D(Q_V)

    return Q_V, S_V