diff --git a/popcor/examples/rotation_lifter.py b/popcor/examples/rotation_lifter.py index 6a7cf97..7d824ed 100644 --- a/popcor/examples/rotation_lifter.py +++ b/popcor/examples/rotation_lifter.py @@ -1,5 +1,6 @@ """RotationAveraging class for rotation averaging and synchronization problems.""" +import warnings from typing import Any, Dict, List, Tuple import numpy as np @@ -8,6 +9,7 @@ from scipy.spatial.transform import Rotation from popcor.base_lifters import StateLifter +from popcor.utils.plotting_tools import LINESTYLES METHOD: str = "CG" SOLVER_KWARGS: dict = dict( @@ -174,8 +176,16 @@ def get_theta(self, x: np.ndarray) -> np.ndarray: return C_flat.reshape((self.d, self.n_rot * self.d), order="F") elif self.level == "bm": R0 = x[: self.d, : self.d].T + + I = R0.T @ R0 + scale_R = I[0, 0] + np.testing.assert_allclose(np.diag(I), scale_R) + warnings.warn(f"R0 is scaled by {scale_R}", UserWarning) + Ri = np.array(x[self.d : (self.n_rot + 1) * self.d, : self.d]).T - Ri_world = R0.T @ Ri + np.testing.assert_allclose(np.diag(Ri.T @ Ri), scale_R) + Ri_world = R0.T @ Ri / scale_R + np.testing.assert_allclose(np.diag(Ri_world.T @ Ri_world), 1.0) return Ri_world else: raise ValueError(f"Unknown level {self.level} for RotationLifter") @@ -448,14 +458,15 @@ def plot(self, estimates: Dict[str, np.ndarray] = {}) -> Tuple[Any, Any]: ) label = None - linestyles = itertools.cycle(["--", "-.", ":"]) + linestyles = itertools.cycle(LINESTYLES) for label, theta in estimates.items(): + ls = next(linestyles) for i in range(self.n_rot): plot_frame( ax=ax, theta=theta[:, i * self.d : (i + 1) * self.d], label=label, - ls=next(linestyles), + ls=ls, scale=1.0, marker="", r_wc_w=np.hstack([i * 2.0] + [0.0] * (self.d - 1)), # type: ignore diff --git a/popcor/utils/plotting_tools.py b/popcor/utils/plotting_tools.py index 94f1633..bdf129b 100644 --- a/popcor/utils/plotting_tools.py +++ b/popcor/utils/plotting_tools.py @@ -6,6 +6,19 @@ from popcor.utils.geometry import get_C_r_from_theta +LINESTYLES = [ + "-", # solid + "--", # dashed + "-.", # dash-dot + ":", # dotted + (0, (5, 2)), # long dash, short gap + (0, (1, 1)), # very fine dashed (dot style) + (0, (3, 5, 1, 5)), # dash, gap, dot, gap + (0, (5, 10)), # widely spaced dashes + (0, (3, 1, 1, 1, 1, 1)), # dash-dot-dot + (0, (2, 4, 6, 4)), # mixed dash lengths +] + def import_plt(): import shutil diff --git a/tests/test_rotation_lifter.py b/tests/test_rotation_lifter.py index 42668c7..931225d 100644 --- a/tests/test_rotation_lifter.py +++ b/tests/test_rotation_lifter.py @@ -135,16 +135,17 @@ def test_solve_local(): print("done") -def test_solve_sdp(): +@pytest.mark.parametrize("level", ["no", "bm"]) +def test_solve_sdp(level): """Solve the SDP relaxation and compare the recovered rotations to ground truth.""" d = 3 n_abs = 1 - n_rot = 4 + n_rot = 5 estimates = {} - for noise in [0.0, 0.2]: + for noise in [0, 1e-3, 0.2]: np.random.seed(2) lifter = RotationLifter( - d=d, n_abs=n_abs, n_rot=n_rot, sparsity="chain", level="no" + d=d, n_abs=n_abs, n_rot=n_rot, sparsity="chain", level=level ) y = lifter.simulate_y(noise=noise) @@ -160,13 +161,23 @@ def test_solve_sdp(): X, info = solve_sdp(Q, constraints, verbose=False) - x, info_rank = rank_project(X, p=1) + if lifter.level == "no": + x, info_rank = rank_project(X, p=1) + theta_sdp = lifter.get_theta(x) + else: + X, info_rank = rank_project(X, p=lifter.d) + theta_sdp = lifter.get_theta(X) + print(f"EVR at noise {noise:.2f}: {info_rank['EVR']:.2e}") if noise <= 1: assert info_rank["EVR"] > 1e8 - theta_sdp = lifter.get_theta(x) + + error = lifter.get_error(theta_sdp) if noise == 0.0: - np.testing.assert_allclose(theta_sdp, lifter.theta, atol=1e-5) + np.testing.assert_allclose(theta_sdp, lifter.theta, atol=1e-8) + assert error < 1e-10 + elif noise < 1e-2: + assert error < 1e-2 estimates.update({"init gt": lifter.theta, f"SDP noise {noise:.2f}": theta_sdp}) @@ -178,7 +189,8 @@ def test_solve_sdp(): if __name__ == "__main__": - test_solve_sdp() + test_solve_sdp(level="no") + test_solve_sdp(level="bm") test_solve_local() test_measurements(d=2, level="no")